{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 329,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# import SparkSession\n",
    "from pyspark.sql import SparkSession\n",
    "# create Spark session object\n",
    "spark=SparkSession.builder.appName('supervised_ml').getOrCreate()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Regression "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 331,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "df=spark.read.csv('Linear_regression_dataset.csv',inferSchema=True,header=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 332,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.regression import LinearRegression\n",
    "from pyspark.ml.feature import OneHotEncoder, StringIndexer, VectorAssembler\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 333,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1232, 6)\n"
     ]
    }
   ],
   "source": [
    "print((df.count(), len(df.columns)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+-----+-----+-----+-----+-----+\n",
      "|var_1|var_2|var_3|var_4|var_5|label|\n",
      "+-----+-----+-----+-----+-----+-----+\n",
      "|  734|  688|   81|0.328|0.259|0.418|\n",
      "|  700|  600|   94| 0.32|0.247|0.389|\n",
      "|  712|  705|   93|0.311|0.247|0.417|\n",
      "|  734|  806|   69|0.315| 0.26|0.415|\n",
      "|  613|  759|   61|0.302| 0.24|0.378|\n",
      "|  748|  676|   85|0.318|0.255|0.422|\n",
      "|  669|  588|   97|0.315|0.251|0.411|\n",
      "|  667|  845|   68|0.324|0.251|0.381|\n",
      "|  758|  890|   64| 0.33|0.274|0.436|\n",
      "|  726|  670|   88|0.335|0.268|0.422|\n",
      "+-----+-----+-----+-----+-----+-----+\n",
      "only showing top 10 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.show(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 334,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Drop any existing 'features' column first (no-op if absent): VectorAssembler\n",
    "# refuses to overwrite an existing output column, so re-running this cell would fail.\n",
    "df_assembler = VectorAssembler(inputCols=['var_1', 'var_2', 'var_3', 'var_4', 'var_5'], outputCol=\"features\")\n",
    "df = df_assembler.transform(df.drop('features'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 335,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------------+-----+\n",
      "|            features|label|\n",
      "+--------------------+-----+\n",
      "|[734.0,688.0,81.0...|0.418|\n",
      "|[700.0,600.0,94.0...|0.389|\n",
      "|[712.0,705.0,93.0...|0.417|\n",
      "|[734.0,806.0,69.0...|0.415|\n",
      "|[613.0,759.0,61.0...|0.378|\n",
      "|[748.0,676.0,85.0...|0.422|\n",
      "|[669.0,588.0,97.0...|0.411|\n",
      "|[667.0,845.0,68.0...|0.381|\n",
      "|[758.0,890.0,64.0...|0.436|\n",
      "|[726.0,670.0,88.0...|0.422|\n",
      "|[583.0,794.0,55.0...|0.371|\n",
      "|[676.0,746.0,72.0...|  0.4|\n",
      "|[767.0,699.0,89.0...|0.433|\n",
      "|[637.0,597.0,86.0...|0.374|\n",
      "|[609.0,724.0,69.0...|0.382|\n",
      "|[776.0,733.0,83.0...|0.437|\n",
      "|[701.0,832.0,66.0...| 0.39|\n",
      "|[650.0,709.0,74.0...|0.386|\n",
      "|[804.0,668.0,95.0...|0.453|\n",
      "|[713.0,614.0,94.0...|0.404|\n",
      "+--------------------+-----+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.select(['features','label']).show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 337,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Size of train Dataset : 911\n",
      "Size of test Dataset : 321\n"
     ]
    }
   ],
   "source": [
    "# Fixed seed: without it every run draws a different train/test split,\n",
    "# so metrics from separate runs are not comparable.\n",
    "train, test = df.randomSplit([0.75, 0.25], seed=42)\n",
    "print(f\"Size of train Dataset : {train.count()}\" )\n",
    "print(f\"Size of test Dataset : {test.count()}\" )\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 338,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "lr = LinearRegression()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 339,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Fit the model\n",
    "lr_model = lr.fit(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "predictions_df=lr_model.transform(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "predictions_df.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 340,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model_predictions=lr_model.evaluate(test)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 345,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8855561089304634\n"
     ]
    }
   ],
   "source": [
    "print(model_predictions.r2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 346,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.00013305453514672318\n"
     ]
    }
   ],
   "source": [
    "print(model_predictions.meanSquaredError)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 348,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#import the GLM model\n",
    "from pyspark.ml.regression import GeneralizedLinearRegression\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 365,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "glr = GeneralizedLinearRegression()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 366,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "glr_model = glr.fit(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 351,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DenseVector([0.0003, 0.0001, 0.0002, -0.6258, 0.461])"
      ]
     },
     "execution_count": 351,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "glr_model.coefficients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 352,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Coefficients:\n",
       "    Feature Estimate Std Error T Value P Value\n",
       "(Intercept)   0.1887    0.0169 11.1450  0.0000\n",
       "      var_1   0.0003    0.0000 22.5404  0.0000\n",
       "      var_2   0.0001    0.0000  4.3525  0.0000\n",
       "      var_3   0.0002    0.0001  1.7469  0.0810\n",
       "      var_4  -0.6258    0.0707 -8.8489  0.0000\n",
       "      var_5   0.4610    0.0626  7.3697  0.0000\n",
       "\n",
       "(Dispersion parameter for gaussian family taken to be 0.0001)\n",
       "    Null deviance: 0.9886 on 905 degrees of freedom\n",
       "Residual deviance: 0.1355 on 905 degrees of freedom\n",
       "AIC: -5429.3467"
      ]
     },
     "execution_count": 352,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "glr_model.summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 367,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model_predictions=glr_model.evaluate(test)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 368,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+-----+-----+-----+-----+-----+--------------------+-------------------+\n",
      "|var_1|var_2|var_3|var_4|var_5|label|            features|         prediction|\n",
      "+-----+-----+-----+-----+-----+-----+--------------------+-------------------+\n",
      "|  473|  499|   73|0.281|0.228|0.315|[473.0,499.0,73.0...| 0.3168228205906638|\n",
      "|  498|  672|   61|0.288|0.238|0.325|[498.0,672.0,61.0...|0.33224574821552433|\n",
      "|  513|  698|   61|0.298|0.236|0.339|[513.0,698.0,61.0...| 0.3314803948399759|\n",
      "|  527|  569|   75|0.297|0.239|0.341|[527.0,569.0,75.0...| 0.3341226140464902|\n",
      "|  532|  690|   69|0.303|0.245|0.351|[532.0,690.0,69.0...| 0.3399564421746622|\n",
      "|  534|  609|   69|0.304|0.229|0.329|[534.0,609.0,69.0...|0.32848240366793496|\n",
      "|  536|  531|   83|0.292|0.214|0.318|[536.0,531.0,83.0...|  0.328257290790113|\n",
      "|  541|  830|   60|0.302|0.229| 0.33|[541.0,830.0,60.0...| 0.3418227186125283|\n",
      "|  543|  615|   76|0.294|0.233|0.333|[543.0,615.0,76.0...|0.34119126494490326|\n",
      "|  550|  631|   76|0.306|0.235|0.318|[550.0,631.0,76.0...| 0.3377949136360593|\n",
      "|  550|  789|   54|0.305|0.238|0.359|[550.0,789.0,54.0...| 0.3439728742988891|\n",
      "|  557|  659|   71|0.295|0.234|0.355|[557.0,659.0,71.0...|0.34713268891908144|\n",
      "|  559|  613|   75|0.293|0.235|0.359|[559.0,613.0,75.0...|0.34788019044256024|\n",
      "|  569|  620|   77|0.302|0.247|0.349|[569.0,620.0,77.0...|0.35188383904315806|\n",
      "|  569|  711|   65|0.305|0.237| 0.34|[569.0,711.0,65.0...| 0.3479171474465841|\n",
      "|  570|  655|   66|0.311|0.246| 0.34|[570.0,655.0,66.0...| 0.3459593594270007|\n",
      "|  570|  662|   73| 0.31|0.247|0.337|[570.0,662.0,73.0...| 0.3486539304568965|\n",
      "|  570|  786|   57|0.301|0.249|0.366|[570.0,786.0,57.0...|0.35870641360119765|\n",
      "|  572|  646|   71|0.311|0.235|0.329|[572.0,646.0,71.0...| 0.3419971659838297|\n",
      "|  573|  634|   75|0.308|0.244|0.342|[573.0,634.0,75.0...|0.34846159793536235|\n",
      "+-----+-----+-----+-----+-----+-----+--------------------+-------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "model_predictions.predictions.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 369,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-1939.8866851859273"
      ]
     },
     "execution_count": 369,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_predictions.aic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 359,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "336.9915101112133"
      ]
     },
     "execution_count": 359,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "glr = GeneralizedLinearRegression(family='Binomial')\n",
    "glr_model = glr.fit(train)\n",
    "model_predictions=glr_model.evaluate(test)\n",
    "model_predictions.aic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 360,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "266.53028779813695"
      ]
     },
     "execution_count": 360,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "glr = GeneralizedLinearRegression(family='Poisson')\n",
    "glr_model = glr.fit(train)\n",
    "model_predictions=glr_model.evaluate(test)\n",
    "model_predictions.aic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 364,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-1903.815764540875"
      ]
     },
     "execution_count": 364,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "glr = GeneralizedLinearRegression(family='Gamma')\n",
    "glr_model = glr.fit(train)\n",
    "model_predictions=glr_model.evaluate(test)\n",
    "model_predictions.aic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 363,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-1939.8866851859273"
      ]
     },
     "execution_count": 363,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "glr = GeneralizedLinearRegression(family='Tweedie')\n",
    "glr_model = glr.fit(train)\n",
    "model_predictions=glr_model.evaluate(test)\n",
    "model_predictions.aic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 370,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.regression import DecisionTreeRegressor\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 371,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "dec_tree = DecisionTreeRegressor()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 372,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Train the model on the training split.\n",
    "dec_tree_model = dec_tree.fit(train)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 373,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SparseVector(5, {0: 0.9667, 1: 0.0124, 2: 0.0045, 3: 0.0053, 4: 0.0112})"
      ]
     },
     "execution_count": 373,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dec_tree_model.featureImportances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Make predictions.\n",
    "model_predictions = dec_tree_model.transform(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+-----+-----+-----+-----+-----+--------------------+-------------------+\n",
      "|var_1|var_2|var_3|var_4|var_5|label|            features|         prediction|\n",
      "+-----+-----+-----+-----+-----+-----+--------------------+-------------------+\n",
      "|  464|  640|   66|0.283| 0.22|0.301|[464.0,640.0,66.0...|            0.31925|\n",
      "|  501|  774|   51|0.285|0.219|0.315|[501.0,774.0,51.0...|               0.33|\n",
      "|  533|  660|   62|0.296|0.233| 0.33|[533.0,660.0,62.0...|               0.33|\n",
      "|  534|  609|   69|0.304|0.229|0.329|[534.0,609.0,69.0...|            0.31925|\n",
      "|  559|  613|   75|0.293|0.235|0.359|[559.0,613.0,75.0...|0.34612195121951217|\n",
      "|  562|  587|   80|0.308|0.235|0.344|[562.0,587.0,80.0...|0.34612195121951217|\n",
      "|  564|  648|   74|0.294|0.236|0.337|[564.0,648.0,74.0...|0.34612195121951217|\n",
      "|  568|  708|   57|0.311|0.247|0.347|[568.0,708.0,57.0...|0.34612195121951217|\n",
      "|  569|  544|   82|0.304| 0.24|0.343|[569.0,544.0,82.0...|0.34612195121951217|\n",
      "|  571|  577|   83|0.298|0.251|0.368|[571.0,577.0,83.0...| 0.3534705882352937|\n",
      "|  573|  656|   75|0.313|0.242|0.345|[573.0,656.0,75.0...|0.34612195121951217|\n",
      "|  574|  556|   85|0.303|0.243|0.368|[574.0,556.0,85.0...|0.34612195121951217|\n",
      "|  575|  680|   68|  0.3|0.241|0.344|[575.0,680.0,68.0...|0.34612195121951217|\n",
      "|  576|  759|   57|0.313|0.254| 0.35|[576.0,759.0,57.0...|0.35724999999999996|\n",
      "|  578|  633|   76|0.309|0.249|0.337|[578.0,633.0,76.0...| 0.3534705882352937|\n",
      "|  578|  733|   62|0.299|0.231|0.348|[578.0,733.0,62.0...|0.35724999999999996|\n",
      "|  579|  497|   91|0.304|0.225|0.352|[579.0,497.0,91.0...|0.34612195121951217|\n",
      "|  581|  724|   64|0.314|0.248|0.346|[581.0,724.0,64.0...|0.35724999999999996|\n",
      "|  582|  791|   52| 0.31| 0.24|0.359|[582.0,791.0,52.0...|0.35724999999999996|\n",
      "|  584|  680|   63|0.298|0.234| 0.35|[584.0,680.0,63.0...|0.34612195121951217|\n",
      "+-----+-----+-----+-----+-----+-----+--------------------+-------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "model_predictions.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.evaluation import RegressionEvaluator\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The r-square value of DecisionTreeRegressor is 0.8093834699203476\n",
      "The rmse value of DecisionTreeRegressor is 0.014111932287681688\n"
     ]
    }
   ],
   "source": [
    "# R2 value of the model on test data \n",
    "dt_evaluator = RegressionEvaluator(metricName='r2')\n",
    "dt_r2 = dt_evaluator.evaluate(model_predictions)\n",
    "print(f'The r-square value of DecisionTreeRegressor is {dt_r2}')\n",
    "\n",
    "# RMSE value of the model on test data \n",
    "dt_evaluator = RegressionEvaluator(metricName='rmse')\n",
    "dt_rmse = dt_evaluator.evaluate(model_predictions)\n",
    "print(f'The rmse value of DecisionTreeRegressor is {dt_rmse}')\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 374,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.regression import RandomForestRegressor\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 122,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "rf = RandomForestRegressor()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 123,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Train the model on the training split.\n",
    "rf_model = rf.fit(train)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SparseVector(5, {0: 0.4395, 1: 0.045, 2: 0.0243, 3: 0.2725, 4: 0.2188})"
      ]
     },
     "execution_count": 125,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rf_model.featureImportances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "20"
      ]
     },
     "execution_count": 126,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rf_model.getNumTrees"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 132,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model_predictions = rf_model.transform(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 133,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+-----+-----+-----+-----+-----+--------------------+-------------------+\n",
      "|var_1|var_2|var_3|var_4|var_5|label|            features|         prediction|\n",
      "+-----+-----+-----+-----+-----+-----+--------------------+-------------------+\n",
      "|  464|  640|   66|0.283| 0.22|0.301|[464.0,640.0,66.0...|0.32620528864346005|\n",
      "|  501|  774|   51|0.285|0.219|0.315|[501.0,774.0,51.0...| 0.3315483384956547|\n",
      "|  533|  660|   62|0.296|0.233| 0.33|[533.0,660.0,62.0...| 0.3279678481672696|\n",
      "|  534|  609|   69|0.304|0.229|0.329|[534.0,609.0,69.0...|0.32849289839181284|\n",
      "|  559|  613|   75|0.293|0.235|0.359|[559.0,613.0,75.0...|0.34431157381674565|\n",
      "|  562|  587|   80|0.308|0.235|0.344|[562.0,587.0,80.0...|0.34633939901939004|\n",
      "|  564|  648|   74|0.294|0.236|0.337|[564.0,648.0,74.0...|0.34660634825283587|\n",
      "|  568|  708|   57|0.311|0.247|0.347|[568.0,708.0,57.0...| 0.3580627405761989|\n",
      "|  569|  544|   82|0.304| 0.24|0.343|[569.0,544.0,82.0...| 0.3489929326622356|\n",
      "|  571|  577|   83|0.298|0.251|0.368|[571.0,577.0,83.0...| 0.3538474256428915|\n",
      "|  573|  656|   75|0.313|0.242|0.345|[573.0,656.0,75.0...|0.35044723951023526|\n",
      "|  574|  556|   85|0.303|0.243|0.368|[574.0,556.0,85.0...|0.35245659207672475|\n",
      "|  575|  680|   68|  0.3|0.241|0.344|[575.0,680.0,68.0...| 0.3508903073912221|\n",
      "|  576|  759|   57|0.313|0.254| 0.35|[576.0,759.0,57.0...| 0.3667486309616494|\n",
      "|  578|  633|   76|0.309|0.249|0.337|[578.0,633.0,76.0...|0.35079976893996123|\n",
      "|  578|  733|   62|0.299|0.231|0.348|[578.0,733.0,62.0...| 0.3466773379187134|\n",
      "|  579|  497|   91|0.304|0.225|0.352|[579.0,497.0,91.0...|0.34216301643346114|\n",
      "|  581|  724|   64|0.314|0.248|0.346|[581.0,724.0,64.0...| 0.3584021380733537|\n",
      "|  582|  791|   52| 0.31| 0.24|0.359|[582.0,791.0,52.0...| 0.3583885211293452|\n",
      "|  584|  680|   63|0.298|0.234| 0.35|[584.0,680.0,63.0...|  0.350146054051407|\n",
      "+-----+-----+-----+-----+-----+-----+--------------------+-------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "model_predictions.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 135,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The r-square value of RandomForestRegressor is 0.8215863293044671\n",
      "The rmse value of RandomForestRegressor is 0.01365275410722947\n"
     ]
    }
   ],
   "source": [
    "# Select (prediction, true label) and compute test error\n",
    "# R2 value of the model on test data \n",
    "rf_evaluator = RegressionEvaluator(metricName='r2')\n",
    "rf_r2 = rf_evaluator.evaluate(model_predictions)\n",
    "print(f'The r-square value of RandomForestRegressor is {rf_r2}')\n",
    "\n",
    "# RMSE value of the model on test data \n",
    "rf_evaluator = RegressionEvaluator(metricName='rmse')\n",
    "rf_rmse = rf_evaluator.evaluate(model_predictions)\n",
    "print(f'The rmse value of RandomForestRegressor is {rf_rmse}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 375,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.regression import GBTRegressor\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 376,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "gbt = GBTRegressor()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 377,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Train the model on the training split.\n",
    "gbt_model = gbt.fit(train)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 378,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SparseVector(5, {0: 0.2632, 1: 0.1731, 2: 0.2107, 3: 0.2369, 4: 0.1161})"
      ]
     },
     "execution_count": 378,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gbt_model.featureImportances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 379,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model_predictions = gbt_model.transform(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 380,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+-----+-----+-----+-----+-----+--------------------+-------------------+\n",
      "|var_1|var_2|var_3|var_4|var_5|label|            features|         prediction|\n",
      "+-----+-----+-----+-----+-----+-----+--------------------+-------------------+\n",
      "|  473|  499|   73|0.281|0.228|0.315|[473.0,499.0,73.0...|0.31944398058332746|\n",
      "|  498|  672|   61|0.288|0.238|0.325|[498.0,672.0,61.0...|0.34408906567588476|\n",
      "|  513|  698|   61|0.298|0.236|0.339|[513.0,698.0,61.0...|0.34431612842121734|\n",
      "|  527|  569|   75|0.297|0.239|0.341|[527.0,569.0,75.0...|0.34117798058332727|\n",
      "|  532|  690|   69|0.303|0.245|0.351|[532.0,690.0,69.0...|0.33924893158758207|\n",
      "|  534|  609|   69|0.304|0.229|0.329|[534.0,609.0,69.0...| 0.3173858203666073|\n",
      "|  536|  531|   83|0.292|0.214|0.318|[536.0,531.0,83.0...|0.32116220648685906|\n",
      "|  541|  830|   60|0.302|0.229| 0.33|[541.0,830.0,60.0...|0.32107791151173165|\n",
      "|  543|  615|   76|0.294|0.233|0.333|[543.0,615.0,76.0...| 0.3191040462701389|\n",
      "|  550|  631|   76|0.306|0.235|0.318|[550.0,631.0,76.0...| 0.3363031821392068|\n",
      "|  550|  789|   54|0.305|0.238|0.359|[550.0,789.0,54.0...|0.34106264054462493|\n",
      "|  557|  659|   71|0.295|0.234|0.355|[557.0,659.0,71.0...| 0.3484122333478269|\n",
      "|  559|  613|   75|0.293|0.235|0.359|[559.0,613.0,75.0...| 0.3493445975338579|\n",
      "|  569|  620|   77|0.302|0.247|0.349|[569.0,620.0,77.0...| 0.3584779456850221|\n",
      "|  569|  711|   65|0.305|0.237| 0.34|[569.0,711.0,65.0...|0.35012490558846804|\n",
      "|  570|  655|   66|0.311|0.246| 0.34|[570.0,655.0,66.0...|0.34654473492263055|\n",
      "|  570|  662|   73| 0.31|0.247|0.337|[570.0,662.0,73.0...| 0.3565935879935122|\n",
      "|  570|  786|   57|0.301|0.249|0.366|[570.0,786.0,57.0...| 0.3512264728004555|\n",
      "|  572|  646|   71|0.311|0.235|0.329|[572.0,646.0,71.0...|0.34847298989434594|\n",
      "|  573|  634|   75|0.308|0.244|0.342|[573.0,634.0,75.0...|0.34833018914621916|\n",
      "+-----+-----+-----+-----+-----+-----+--------------------+-------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "model_predictions.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 381,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The r-square value of GradientBoostedRegressor is 0.8477273892307596\n",
      "The rmse value of GradientBoostedRegressor is 0.013305445803592103\n"
     ]
    }
   ],
   "source": [
    "# Select (prediction, true label) and compute test error\n",
    "# R2 value of the model on test data \n",
    "gbt_evaluator = RegressionEvaluator(metricName='r2')\n",
    "gbt_r2 = gbt_evaluator.evaluate(model_predictions)\n",
    "print(f'The r-square value of GradientBoostedRegressor is {gbt_r2}')\n",
    "\n",
    "# RMSE value of the model on test data \n",
    "gbt_evaluator = RegressionEvaluator(metricName='rmse')\n",
    "gbt_rmse = gbt_evaluator.evaluate(model_predictions)\n",
    "print(f'The rmse value of GradientBoostedRegressor is {gbt_rmse}')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 382,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Load csv Dataset \n",
    "df=spark.read.csv('bank_data.csv',inferSchema=True,header=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 383,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "41188"
      ]
     },
     "execution_count": 383,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#number of records\n",
    "df.count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 384,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['age',\n",
       " 'job',\n",
       " 'marital',\n",
       " 'education',\n",
       " 'default',\n",
       " 'housing',\n",
       " 'loan',\n",
       " 'contact',\n",
       " 'month',\n",
       " 'day_of_week',\n",
       " 'duration',\n",
       " 'campaign',\n",
       " 'pdays',\n",
       " 'previous',\n",
       " 'poutcome',\n",
       " 'emp.var.rate',\n",
       " 'cons.price.idx',\n",
       " 'cons.conf.idx',\n",
       " 'euribor3m',\n",
       " 'nr.employed',\n",
       " 'target_class']"
      ]
     },
     "execution_count": 384,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- age: integer (nullable = true)\n",
      " |-- job: string (nullable = true)\n",
      " |-- marital: string (nullable = true)\n",
      " |-- education: string (nullable = true)\n",
      " |-- default: string (nullable = true)\n",
      " |-- housing: string (nullable = true)\n",
      " |-- loan: string (nullable = true)\n",
      " |-- contact: string (nullable = true)\n",
      " |-- month: string (nullable = true)\n",
      " |-- day_of_week: string (nullable = true)\n",
      " |-- duration: integer (nullable = true)\n",
      " |-- campaign: integer (nullable = true)\n",
      " |-- pdays: integer (nullable = true)\n",
      " |-- previous: integer (nullable = true)\n",
      " |-- poutcome: string (nullable = true)\n",
      " |-- emp.var.rate: double (nullable = true)\n",
      " |-- cons.price.idx: double (nullable = true)\n",
      " |-- cons.conf.idx: double (nullable = true)\n",
      " |-- euribor3m: double (nullable = true)\n",
      " |-- nr.employed: double (nullable = true)\n",
      " |-- target_class: string (nullable = true)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# data types of input data\n",
    "df.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------+-----+\n",
      "|target_class|count|\n",
      "+------------+-----+\n",
      "|          no|36548|\n",
      "|         yes| 4640|\n",
      "+------------+-----+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.groupBy('target_class').count().show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 395,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Sample dataset for classification (roughly balanced yes/no target, see counts below)\n",
    "\n",
    "df=spark.read.csv('binary_class.csv',inferSchema=True,header=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 396,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9501"
      ]
     },
     "execution_count": 396,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 397,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['age',\n",
       " 'job',\n",
       " 'marital',\n",
       " 'education',\n",
       " 'default',\n",
       " 'housing',\n",
       " 'loan',\n",
       " 'target_class']"
      ]
     },
     "execution_count": 397,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 398,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- age: integer (nullable = true)\n",
      " |-- job: string (nullable = true)\n",
      " |-- marital: string (nullable = true)\n",
      " |-- education: string (nullable = true)\n",
      " |-- default: string (nullable = true)\n",
      " |-- housing: string (nullable = true)\n",
      " |-- loan: string (nullable = true)\n",
      " |-- target_class: string (nullable = true)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 399,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------+-----+\n",
      "|target_class|count|\n",
      "+------------+-----+\n",
      "|          no| 4861|\n",
      "|         yes| 4640|\n",
      "+------------+-----+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.groupBy('target_class').count().show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 400,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.sql import functions as F\n",
    "from pyspark.sql import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 401,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "df=df.withColumn(\"label\", F.when(df.target_class =='no', F.lit(0)).otherwise(F.lit(1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 402,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+-----+\n",
      "|label|count|\n",
      "+-----+-----+\n",
      "|    1| 4640|\n",
      "|    0| 4861|\n",
      "+-----+-----+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.groupBy('label').count().show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 403,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.feature import OneHotEncoder, StringIndexer, VectorAssembler\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 404,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def cat_to_num(df):\n",
    "    \n",
    "    for col in df.columns:\n",
    "        stringIndexer = StringIndexer(inputCol=col, outputCol=col+\"_index\")\n",
    "        model = stringIndexer.fit(df)\n",
    "        indexed = model.transform(df)\n",
    "\n",
    "        encoder = OneHotEncoder(inputCol=col+\"_index\", outputCol=col+\"_vec\")\n",
    "        df = encoder.transform(indexed)\n",
    "    df_assembler = VectorAssembler(inputCols=['age','marital_vec','education_vec','default_vec','housing_vec','loan_vec'], outputCol=\"features\")\n",
    "    df = df_assembler.transform(df)  \n",
    "    return df.select(['features','label'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 405,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "df_new=cat_to_num(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 406,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------------+-----+\n",
      "|            features|label|\n",
      "+--------------------+-----+\n",
      "|(16,[0,1,8,11,13,...|    0|\n",
      "|(16,[0,1,5,13,14]...|    0|\n",
      "|(16,[0,1,5,11,12,...|    0|\n",
      "|(16,[0,1,9,11,13,...|    0|\n",
      "|(16,[0,1,5,11,13,...|    0|\n",
      "|(16,[0,1,6,13,14]...|    0|\n",
      "|(16,[0,1,7,11,13,...|    0|\n",
      "|(16,[0,1,10,13,14...|    0|\n",
      "|(16,[0,2,7,11,12,...|    0|\n",
      "|(16,[0,2,5,11,12,...|    0|\n",
      "|(16,[0,1,10,13,14...|    0|\n",
      "|(16,[0,2,5,11,12,...|    0|\n",
      "|(16,[0,2,5,11,13,...|    0|\n",
      "|(16,[0,3,8,11,12,...|    0|\n",
      "|(16,[0,1,9,11,12,...|    0|\n",
      "|(16,[0,1,6,12,15]...|    0|\n",
      "|(16,[0,1,9,11,12,...|    0|\n",
      "|(16,[0,1,9,12,15]...|    0|\n",
      "|(16,[0,1,6,11,12,...|    0|\n",
      "|(16,[0,2,6,13,14]...|    0|\n",
      "+--------------------+-----+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df_new.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 407,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+-----+\n",
      "|label|count|\n",
      "+-----+-----+\n",
      "|    1| 4640|\n",
      "|    0| 4861|\n",
      "+-----+-----+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df_new.groupBy('label').count().show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 408,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Size of train Dataset : 7112\n",
      "Size of test Dataset : 2389\n"
     ]
    }
   ],
   "source": [
    "train, test = df_new.randomSplit([0.75, 0.25])\n",
    "print(f\"Size of train Dataset : {train.count()}\" )\n",
    "print(f\"Size of test Dataset : {test.count()}\" )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 409,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.classification import LogisticRegression\n",
    "lr = LogisticRegression()\n",
    "lr_model = lr.fit(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 410,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.0273012403206,0.0860736571977,0.934384469874,-0.016522545073,-12.9801090231,-13.469005707,-13.8118640649,-13.3288968418,-13.6023406483,-13.7420942762,-13.2940419118,1.3921700408,0.29133718088,-0.0418899174088,0.248880549172,0.218739541667]\n"
     ]
    }
   ],
   "source": [
    "print(lr_model.coefficients)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 411,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "lr_summary=lr_model.summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 412,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6732283464566929"
      ]
     },
     "execution_count": 412,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lr_summary.accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 413,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7161951680555275"
      ]
     },
     "execution_count": 413,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lr_summary.areaUnderROC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 414,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.6948130277442702, 0.6543730242360379]\n"
     ]
    }
   ],
   "source": [
    "\n",
    "print(lr_summary.precisionByLabel)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 415,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------------+------------------+\n",
      "|         threshold|         precision|\n",
      "+------------------+------------------+\n",
      "|0.9999990369941395|0.8378378378378378|\n",
      "|0.8220499817791618|0.7647058823529411|\n",
      "|0.7968126369992325|0.7734806629834254|\n",
      "| 0.773470522075655|0.7892857142857143|\n",
      "|0.7612946912751203|0.7645502645502645|\n",
      "|0.7490362058233141|0.7540983606557377|\n",
      "|0.7382521496703287|0.7345132743362832|\n",
      "|0.7279553838380405|0.7336448598130841|\n",
      "|0.7193784940839227|0.7327823691460055|\n",
      "|0.7122887285503471|0.7333333333333333|\n",
      "|0.7047990854701919|0.7276736493936052|\n",
      "|0.6983210316089354|0.7244174265450861|\n",
      "|0.6895039811348334|0.7183226982680037|\n",
      "|0.6828171155434885|0.7237615449202351|\n",
      "|0.6769693285452459|0.7262658227848101|\n",
      "|0.6702627333842052|  0.72619926199262|\n",
      "|0.6639336170490286|0.7254355400696864|\n",
      "|0.6574455562430638|0.7240924092409241|\n",
      "|0.6519155436691656|0.7210327455919395|\n",
      "|0.6462446109735182|0.7204628501827041|\n",
      "+------------------+------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "lr_summary.precisionByThreshold.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 416,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>threshold</th>\n",
       "      <th>precision</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.999997</td>\n",
       "      <td>0.888889</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.822831</td>\n",
       "      <td>0.758242</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.796765</td>\n",
       "      <td>0.720779</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.781076</td>\n",
       "      <td>0.775362</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.765103</td>\n",
       "      <td>0.765714</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>0.754244</td>\n",
       "      <td>0.748869</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>0.744988</td>\n",
       "      <td>0.735010</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>0.734741</td>\n",
       "      <td>0.730263</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>0.724762</td>\n",
       "      <td>0.733529</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>0.715474</td>\n",
       "      <td>0.737047</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>0.708786</td>\n",
       "      <td>0.719723</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>0.702597</td>\n",
       "      <td>0.722281</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>0.695284</td>\n",
       "      <td>0.724568</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>0.688088</td>\n",
       "      <td>0.722173</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>0.681117</td>\n",
       "      <td>0.725329</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>0.674421</td>\n",
       "      <td>0.725155</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>0.668611</td>\n",
       "      <td>0.720351</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>0.661924</td>\n",
       "      <td>0.722572</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>0.656448</td>\n",
       "      <td>0.722892</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>0.650878</td>\n",
       "      <td>0.723828</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>0.644870</td>\n",
       "      <td>0.723443</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>0.639960</td>\n",
       "      <td>0.717709</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>0.635317</td>\n",
       "      <td>0.714527</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>0.630246</td>\n",
       "      <td>0.713053</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>0.623876</td>\n",
       "      <td>0.707946</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>0.617901</td>\n",
       "      <td>0.706495</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>0.612636</td>\n",
       "      <td>0.707502</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>0.606000</td>\n",
       "      <td>0.704753</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>0.600930</td>\n",
       "      <td>0.701747</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>0.594389</td>\n",
       "      <td>0.694794</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>72</th>\n",
       "      <td>0.327900</td>\n",
       "      <td>0.558217</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>73</th>\n",
       "      <td>0.317414</td>\n",
       "      <td>0.554896</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>74</th>\n",
       "      <td>0.310297</td>\n",
       "      <td>0.552522</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75</th>\n",
       "      <td>0.301704</td>\n",
       "      <td>0.550283</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>76</th>\n",
       "      <td>0.291484</td>\n",
       "      <td>0.548157</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>77</th>\n",
       "      <td>0.284373</td>\n",
       "      <td>0.545623</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>78</th>\n",
       "      <td>0.272514</td>\n",
       "      <td>0.543264</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>79</th>\n",
       "      <td>0.263907</td>\n",
       "      <td>0.541188</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>80</th>\n",
       "      <td>0.255375</td>\n",
       "      <td>0.538664</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>81</th>\n",
       "      <td>0.246590</td>\n",
       "      <td>0.536194</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>82</th>\n",
       "      <td>0.237072</td>\n",
       "      <td>0.533755</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>83</th>\n",
       "      <td>0.228888</td>\n",
       "      <td>0.531527</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>84</th>\n",
       "      <td>0.222984</td>\n",
       "      <td>0.528706</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>85</th>\n",
       "      <td>0.215858</td>\n",
       "      <td>0.526190</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>86</th>\n",
       "      <td>0.208246</td>\n",
       "      <td>0.524283</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>87</th>\n",
       "      <td>0.202366</td>\n",
       "      <td>0.521726</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>88</th>\n",
       "      <td>0.196637</td>\n",
       "      <td>0.519064</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>89</th>\n",
       "      <td>0.189674</td>\n",
       "      <td>0.516238</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>90</th>\n",
       "      <td>0.183445</td>\n",
       "      <td>0.513514</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>91</th>\n",
       "      <td>0.177102</td>\n",
       "      <td>0.510990</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>92</th>\n",
       "      <td>0.171059</td>\n",
       "      <td>0.508041</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>93</th>\n",
       "      <td>0.165146</td>\n",
       "      <td>0.505442</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>94</th>\n",
       "      <td>0.157890</td>\n",
       "      <td>0.503258</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>95</th>\n",
       "      <td>0.152254</td>\n",
       "      <td>0.500367</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>96</th>\n",
       "      <td>0.145723</td>\n",
       "      <td>0.497670</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>97</th>\n",
       "      <td>0.137731</td>\n",
       "      <td>0.494944</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>98</th>\n",
       "      <td>0.130693</td>\n",
       "      <td>0.492640</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>99</th>\n",
       "      <td>0.119356</td>\n",
       "      <td>0.490494</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>100</th>\n",
       "      <td>0.109876</td>\n",
       "      <td>0.488529</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>101</th>\n",
       "      <td>0.093096</td>\n",
       "      <td>0.488134</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>102 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     threshold  precision\n",
       "0     0.999997   0.888889\n",
       "1     0.822831   0.758242\n",
       "2     0.796765   0.720779\n",
       "3     0.781076   0.775362\n",
       "4     0.765103   0.765714\n",
       "5     0.754244   0.748869\n",
       "6     0.744988   0.735010\n",
       "7     0.734741   0.730263\n",
       "8     0.724762   0.733529\n",
       "9     0.715474   0.737047\n",
       "10    0.708786   0.719723\n",
       "11    0.702597   0.722281\n",
       "12    0.695284   0.724568\n",
       "13    0.688088   0.722173\n",
       "14    0.681117   0.725329\n",
       "15    0.674421   0.725155\n",
       "16    0.668611   0.720351\n",
       "17    0.661924   0.722572\n",
       "18    0.656448   0.722892\n",
       "19    0.650878   0.723828\n",
       "20    0.644870   0.723443\n",
       "21    0.639960   0.717709\n",
       "22    0.635317   0.714527\n",
       "23    0.630246   0.713053\n",
       "24    0.623876   0.707946\n",
       "25    0.617901   0.706495\n",
       "26    0.612636   0.707502\n",
       "27    0.606000   0.704753\n",
       "28    0.600930   0.701747\n",
       "29    0.594389   0.694794\n",
       "..         ...        ...\n",
       "72    0.327900   0.558217\n",
       "73    0.317414   0.554896\n",
       "74    0.310297   0.552522\n",
       "75    0.301704   0.550283\n",
       "76    0.291484   0.548157\n",
       "77    0.284373   0.545623\n",
       "78    0.272514   0.543264\n",
       "79    0.263907   0.541188\n",
       "80    0.255375   0.538664\n",
       "81    0.246590   0.536194\n",
       "82    0.237072   0.533755\n",
       "83    0.228888   0.531527\n",
       "84    0.222984   0.528706\n",
       "85    0.215858   0.526190\n",
       "86    0.208246   0.524283\n",
       "87    0.202366   0.521726\n",
       "88    0.196637   0.519064\n",
       "89    0.189674   0.516238\n",
       "90    0.183445   0.513514\n",
       "91    0.177102   0.510990\n",
       "92    0.171059   0.508041\n",
       "93    0.165146   0.505442\n",
       "94    0.157890   0.503258\n",
       "95    0.152254   0.500367\n",
       "96    0.145723   0.497670\n",
       "97    0.137731   0.494944\n",
       "98    0.130693   0.492640\n",
       "99    0.119356   0.490494\n",
       "100   0.109876   0.488529\n",
       "101   0.093096   0.488134\n",
       "\n",
       "[102 rows x 2 columns]"
      ]
     },
     "execution_count": 416,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "precision_threshold"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 417,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.6/site-packages/matplotlib/axes/_axes.py:545: UserWarning: No labelled objects found. Use label='...' kwarg on individual plots.\n",
      "  warnings.warn(\"No labelled objects found. \"\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xd81eXd//HXlU0gAUL2Zu9pZONAUaRWqrdasc46utTW2ru1v96j475v77Z3a9VqbWsV9x7FUWtFpswwZRMgIQmQSUIG2dfvj+sgKQSIcJKTc/J+Ph7nAeecb875fB9ffXPlWl9jrUVERAJLkK8LEBER71O4i4gEIIW7iEgAUriLiAQghbuISABSuIuIBCCFu4hIAFK4i4gEIIW7iEgACvHVF8fGxtrMzExffb2IiF9at25dqbU27kzH+SzcMzMzyc7O9tXXi4j4JWNMXnuOU7eMiEgAUriLiAQghbuISABSuIuIBCCFu4hIADpjuBtjnjbGFBtjtpzifWOMedQYk2OM2WyMmeD9MkVE5ItoT8t9PjD7NO9fAQz2PO4G/nDuZYmIyLk4Y7hba5cC5ac5ZC7wnHVWAX2MMUneKvBEa3PL+eWHO2hp0e0BRUROxRt97ilAfqvnBZ7XTmKMudsYk22MyS4pKTmrL9uUX8EfFu+hqq7prH5eRKQ78Ea4mzZea7NZba39k7U2y1qbFRd3xtWzbYrpGQZAeW3DWf28iEh34I1wLwDSWj1PBQ544XPb1PdYuNco3EVETsUb4b4AuMUza2YyUGmtPeiFz21TTKQL98MKdxGRUzrjxmHGmJeBi4BYY0wB8J9AKIC19kngA2AOkAPUArd3VLGgbhkRkfY4Y7hba+ed4X0LfMdrFZ3BsW4ZtdxFRE7N71ao9gwLJiw4SC13EZHT8LtwN8bQt2eoWu4iIqfhd+EO0DcyjPKaRl+XISLSZflluMf0DKNC3TIiIqfkl+Het2eY+txFRE7DL8M9JjJMfe4iIqfhl+Het2cYFUcbadbmYSIibfLLcI+JDMVaqDyqQVURkbb4ZbhrfxkRkdPzy3A/tgXBYQ2qioi0yS/DvW+kWu4iIqfjl+Eeo/1lREROyy/D/fOWu7plRETa5Jfh3iMsmB6hwWq5i4icgl+GO7iuGe0vIyLSNr8N9749QzVbRkTkFPw33CPDNFtGROQUzngnpi7ncB4UbSGmZzL7y2t9XY2ISJfkfy33rW/DKzeSEN6olruIyCn4X7hHJwOQFlJJVV0Tjc0tPi5IRKTr8b9wj0oCINGUA9qCQESkLf4X7p6We5wtA+CwpkOKiJzEb8O9b3MpoP1lRETa4n/hHtoDIvoQ1VACqFtGRKQt/hfuANHJRNYXA2q5i4i0xT/DPSqJsNoiQDtDioi0xT/DPTqJoKqD9AoP0c6QIiJt8M9wj0qGmmJiI4PUchcRaYN/hnt0EtgWBkTUUF6rqZAiIifyz3CPctMhM8Mr1XIXEWmDf4Z7tFulmhZcqdkyIiJt8NNwTwEgKeiw5rmLiLTBP8M9sh8EhxFPGbUNzVQeVb+7iEhr/hnuxkBUIv3DKgF45tN9Pi5IRKRraVe4G2NmG2N2GmNyjDEPtvF+ujFmkTFmgzFmszFmjvdLPUFUMn2by5g9MpGnlu2jrLq+w79SRMRfnDHcjTHBwOPAFcAIYJ4xZsQJh/0b8Jq1djxwA/CEtws9SXQSHDnADy4fQm1DE08s3tPhXyki4i/a03KfCORYa/daaxuAV4C5JxxjgWjP33sDB7xX4ilEJUPVQQbF9eLa81J5fmUehRVHO/xrRUT8QXvCPQXIb/W8wPNaaz8FbjLGFAAfAPd6pbrTiU6Cxlqoq+S7lw4B4JGPd3X414qI+IP2hLtp4zV7wvN5wHxrbSowB3jeGHPSZxtj7jbGZBtjsktKSr54ta157shE1UFS+vTg5ikZvLGugIXbi87tc0VEAkB7wr0ASGv1PJWTu13uAF4DsNauBCKA2BM/yFr7J2ttlrU2Ky4u7uwqPsZz0w6OuFIeuGwIo1J6c89LG9iUX3Fu
ny0i4ufaE+5rgcHGmP7GmDDcgOmCE47ZD1wCYIwZjgv3c2yan0GrljtAZFgIf7n1fGKjwvj6/LXkldV06NeLiHRlZwx3a20TcA/wd2A7blbMVmPMz40xV3kOewC4yxizCXgZuM1ae2LXjXcdC/cjBz9/KS4qnPm3T6TZWm57Zq32nRGRbiukPQdZaz/ADZS2fu0/Wv19GzDNu6WdQWiEW6la9c89RAPjevHsV/uz4IXfc++zLTx114VEhAZ3amkiIr7mnytUj4lK/rzPvbWxBS/z78Hz+cWhb/L751+hpaVjf4kQEelq/DvcPQuZTrJvGfTtT2wPw/fy7mHF0/8KHdxLJCLSlfh3uEclfT6g+rn6ajiwHkZeTa/vreazmFlML/gzHz75Q+oam31Tp4hIJ/PvcI9OhpoSaKw7/lr+Kmhpgv4zMD36MOaeV9je7zIuO/Rn/u/Rh7WKVUS6Bf8O9+QJ7s+cj4+/tm8ZBIVC2iQAgoODGP7N56iKGcH9R37NfY+8xPubD7bxYSIigcO/w33gTOgZD5tePv5a7jJIOQ/Ceh5/LbQHvW97nfCe0TxufslPX1rId15ar7s4iUjA8u9wDw6BMdfDrr9DTRnUHYEDG6H/jJOP7Z1CyI0vkxBczUd9f8WGrdu57OElLNnVsWutRER8wb/DHWDsDdDSCFvehP2rwDZDZhvhDpCahbnpDfo2lbIo7rcM7lHDrU+v4aG/baexueXk45vqoSL/5NdFRLq4di1i6tISR0PCaNc1kzkNgsMgbeKpj8+YCje9QfgL1/Jir5+xYOA1/HzJEdbsK+fRa4eTdnQn5H3qunf2r4amo3DdfBh5daedkojIuTIdvUvAqWRlZdns7GzvfNiK38NHP4FeidBvINz+wZl/Jm8lvHsflO6ixQSzuyWV/hQSZprc+wmj3G8ABWugeAfc+TEknHiPEhGRzmWMWWetzTrTcf7fcgfX7/6P/4DqQ3Debe37mYwp8J01ULSVoC1vkpG3lg/Ksni/IoOk0Rfxr1dPJSoiFKoOwR8vhFduhLsXQY++7uebG6F8L5TsgMpC6JsJ8cPd9MyiLa7VX7QFbAuYYAjtAYMudYPAIWHH62isg6Y693nN9VBXCUcroP6I+7mQMAjrBUnj3BiDiEg7BEZa9Ip3wbn7720Ppp6KMZA4ChJHEQFc2dzCnoW7eXxRDgvzlvGra8cwbVAiXP8czP8SvHQDxPSHQ1tcqLc0nqGuBAgOd+MAdZWw9s8Q0cfVWlsGpbvgSGH7ao1KhqzbYcItEJXY/nMUkW4pMLplAPJWwLLfwg0vQkj4OX3UurzD/Ovrm9hbWsPXJqVz/6whxO56Fd77vtusLHE0JIyE+BEQNxR6p0L5PijeBpX57r20Scf3nAdoaoC9i+CzN2DfEre6Nm4o9BvkWubBoW68IKI39OgD4VHQ0uJa81WHYOOLsOcT15pPGgvpk93YQswAF/yR/SDI/8fHReT02tstEzjh7mV1jc3839938vSn+wgLCeKG89P5xvRUkmJ6+66osj2w8SXYvxIK17nunGOCw91vBOPmweDL/7nrR0QChsLdS/aWVPOHxXt4e0MhxsBVY1O4Y3p/RiRHn/mHO1JTAxRvhcoCt6d92W7Y+g7UFLuun7hhbmO1XgnQeBSOHnb9+H0yIHmc68NPGKV/BET8jMLdywoO1/LUsn28lp1PbUMz0wfFcv+swZyXEePr0o5rbnJdN9v+ChV5bsfM6mI3mBsZ41btluW4/n9wrf2ksZB6Poy70Y0/iEiXpnDvIJW1jby4Jo+nl+dSWl3PpcPj+cHlQxmW6OOWfHtZC4dz4cAG17VTkA0HN7ounhFz4cIfuTEDEemSFO4drLahiWc+zeXJJXuorm/iqrHJ3H/pEDJje575h7ua2nJY9Qf3aKiCHjGupR8ZC6lZMPQKSJusqZgiXYDCvZNU1Dbw5JK9zF+xj8Zmy3XnpXLvJYNJ6dPD16V9cbXlsOEF16VT
WwZVRVCw1k35jOjtunD6DYbYwRAe7Wb4BIVAzzg3Yyg6RX34Ih1M4d7JiqvqeGLRHl5avR+L5fqsNL598SD/DPnW6qtcP/7uf0DxdijdDfWVbR9rgiBxDAy6BAZe4hZ19ejr1hOIiFco3H3kQMVRHl+Uw2vZbsOxa89L49sXDSQtJtLHlXmJtVBTCo010NLsNlerKXazdg7nQu5yyF/jFm6BG7TtlQA9erv5/GG9XMs/5TzX5dMnQ+Ev8gUo3H2s4HAtTyzewxvZBTRby9xxydwxvT8jk304T76zHK1wm68dznULsKqL3Qydhmr3Z+lutyEbuO6e+BGulT/gIhgy+5wXoYkEMoV7F3Goso4/Ld3Ly2v2c7SxmayMvtwyNZM5oxIJCe6mK0qbG6FoKxRmQ9E2191TvNUFf4++MOpaGDbHTdEMj/J1tSJdisK9i6msbeT1dfk8vyqPvLJa0mJ68O2LBnHNhBTCQ4J9XZ7vtTS77Rk2vgQ73ndTM02QW2iVMRUyprk/e8b6ulIRn1K4d1EtLZaFO4r5/aIcNuVXkBgdwY2T0rk+K43E3hG+Lq9rqK9ys3T2r3JbLeSvPd6NkzAahs523TfJE7SfjnQ7CvcuzlrLpzll/HHpHpbtLiXIwMxh8cybmM5FQ+MJDtIg4+eaGtxCq9xlsPtjyF/ltlIO7+02T0uf7PbxD49yr/XNcDuFigQghbsf2V9Wy6vZ+3ktu4CSqnqSe0fw1fPTmTcpjfgoteZPUlsOOQshb7nbN79k+8nHRCW5aZkJI91gbdxQt9+OBmvFzync/VBjcwsLtxfx4ur9LNtdSmiw4coxydw+LZMxqX18XV7XdfSw2zyt/ogblC3bAwc3uUfZbmjx3F0rONxNwUyf7KZhJozUVEzxOwp3P7evtIZnV+TyenY+NQ3NDEuM4ivjU7hqbDLJ/r4wqjM1Nbg7ZhVvc3vp7F/pQv9Y4IdFwYALYcp3IH2Kgl66PIV7gDhS18g7Gwp5e0MhG/ZXAHB+Zl/mjE5izugkEqLVbfOFNdS46ZdFW+DgZtj6Nhwtd9sgZ30dhl0JPfv5ukqRNincA1BeWQ0LNh7g/c8OsuNQFcbA1IH9uGZ8KrNHJdIzXBt7nZWGWtj8its4rXSXm4KZMc3tkjlirgZnpUtRuAe4nOJq3t10gLc3FLK/vJbIsGCuPS+VO6cPIL1fgGx10NmshUObYfu7sG0BlO48HvRDr3C3Tkwco83RxKcU7t2EtZZ1eYd5ZW0+f91YSHOLZc7oJO6Y3p/x6X19XZ5/K97uumy2vu1a9AAhEW53zJTz3CNjmrvjlUgnUbh3Q4cq63hmxT5eWrWfqvomxqX14fZpmcwelahVsOfqyEEoWOOmXhZmu0HZY6toB86E8TfBkCsgVGMg0rEU7t1YdX0Tb64rYP6KXPaV1tA3MpS541K4Liu1e2xc1hmaG90MnG0LYNPLcKQQgsPc9Mrk8ZCSBZnTNNVSvM6r4W6MmQ08AgQDT1lr/7eNY64HfgpYYJO19sbTfabCveO1tFiW5ZTyWnY+/9haRENzC5P6x3DPzEFMHxSLUeh4R0sz7FsCe5fAgfVwYNPxPe+jU92eOGkT3SN+pO5oJefEa+FujAkGdgGzgAJgLTDPWrut1TGDgdeAmdbaw8aYeGtt8ek+V+HeuSpqG3hjXQF/XraXoiP1jE3tzc1TMrlCs2y8r6UFSna4bY9zl7s9cqoPufciY2HMV2H813SvWjkr3gz3KcBPrbWXe57/GMBa+1CrY34F7LLWPtXeAhXuvlHf1Myb6wr587K97CutoUdoMLNHJXLLlAwNwHYUa6Ey3/XXb18AO//mbl2YMBqGXwlD50DiaHXfSLt4M9yvBWZba+/0PL8ZmGStvafVMe/gWvfTcF03P7XWftjGZ90N3A2Qnp5+Xl5eXvvPSLzKWsv6/Yd5c30h7246QFVdE1MG9OPbFw9Ul01HqymDz16Dre9A/mrAQlQyZE53/fTp
U6DfIAjSILiczJvhfh1w+QnhPtFae2+rY94DGoHrgVRgGTDKWltxqs9Vy73rqK5v4uXV+/nzsr0UV9WTldGXH10xjPMzY3xdWuCrLoZdH7r71OZ+6m5ZCBAa6fayTznPzcbJnAZhPX1bq3QJnd0t8ySwylo73/N8IfCgtXbtqT5X4d711Dc181p2AY8t3E1xVT2XDIvn1qmZTBnYj9DueteozmQtlOW4vewPbnYLqgrXuSmXwWFuTv3wK932CFGJvq5WfMSb4R6C63K5BCjEDajeaK3d2uqY2bhB1luNMbHABmCctbbsVJ+rcO+6jjY0M39FLk8u2UPl0Ub6RIZy2YgEbpiYzgT1y3euxjrYv8Jtcbzzb1C+BzCQMgEyZ0D/GZA2GcJ7+bpS6STengo5B/gdrj/9aWvtfxtjfg5kW2sXGNdB+xtgNtAM/Le19pXTfabCveura2xm6a4S/rblEP/YVkR1fRMTM2O4+4IBXDxMNxTpdNa6WTjb34Wcj12rvqUJgkLdNsYDZ7pH4mj11wcwLWISr6qub+LVtfk8vXwfhRVHiekZxkVD45g5LJ6Zw+KJDNN0yk7XUOOmWe5dDHsWQdFn7vWIPm5wNmOq2wsncZS78bgEBIW7dIjG5hb+sa2Ij7YeYvGuEipqG+kVHsJV45KZd346o1O1AtZnqg7BvqVuQdW+pVCx//h7vdPcAG3iKBf4aZMgKsF3tcpZU7hLh2tusazNLee17Hze33yQ+qYWhiT0Yu64FL48Jlm7U/paVZFrzR/6DA5tcfvXl+5y958F6Jvp+uv7z4D+F0CfdJ+WK+2jcJdOVXm0kQUbC/nrxgNk5x0GYGJmDDdMTGPO6CQiQtUH3CU0HnVhn7/adensXwm1nnkPfdJd2B+76Xj8SAjSLKmuRuEuPlNwuJa/bjzA69n55JbVEh0RwpVjk5k7NpnzM2MI0kBs19HS4m4wvm+Z2y4hfzVUF7n3Inq7sE+f5Lp04oe77h0tcPMphbv4XEuLZdW+Ml5dm89HW4s42thMUu8IrhiVxOxRiZyX0Vczbroaa6Eiz7Xq81a4R9nu4++HRUHsYIgbCrFDPI/B0Le/bmLSSRTu0qXU1Dfx8fYi/rrxAMt3l9LQ3EJsrzCuHJPM9VlpjEiO9nWJcipHK9wUzOLt7lG6E0p2QtXB48cEhUDsUDdgmzAK4kdA/DCITlFL38sU7tJlVdU1suTY/HnPVsRjUnszb2I6c8cla1qlv6irdCtqS3e78C/a6gZuqw4cPyY8GvoNdHvlxA5xO2EmjNQ+9+dA4S5+oaK2gXc2FPLK2nx2HKoiKiKEf5mQyo2T0hmSEOXr8uRs1JZ7WvrboHiH+wegbI/bGRNP3oRHu9Z9wkhIGHG8Tz9CU2nPROEufuXYvWCfX5XHB58dpLHZMiolmn+ZkMrccSnE9FR/rt+rr3bdOkWeaZlF21xr/9iNTcAN2MYNdV08cUMheZz7RyA41Hd1dzEKd/FbpdX1LNh4gDfXF7D1wBHCQoK4ckwSN0/OYFxaH21HHEishcoC18ov2uLCv2Sn6+ppOuqOCQ53LfzE0a6FnzDSdfX0SuiWXTsKdwkI2w8e4aXV+3lrfQE1Dc0MS4zytOaTiY/WzagDVksLVOTCgY2eWxdudOF/9PDxY0IjIWaAa9knjvasvh0LPfv5rOzOoHCXgFJd38Q7Gwp5fV0Bm/IrCDJw8dB4bpyUzkVDtYlZt2Ctm6FTtA0O74PyfW6aZtFWd4PyY6JTIWmMu1H5sUfPWN/V7WUKdwlYOcXVvLm+gDfWFVBSVU9y7wiuHJvMjMGxnJ8Zo9Ww3VFtudv//tg++Ac3ua6dYwO40amQNNb14R/bY8dPF2Qp3CXgNTa3sHB7ES+vyWflnjIamlsICwlixqBYrhqXzKwRCZpW2Z3VHfHc8GS9C/uDm9zMnWOBH9HbtepTzoOULBf+0cldPvAV7tKt1DY0sXpfOUt3lfDhlkMcrKwjMiyYy0cm
cvX4FKYNilXXjXhm7GzzbKbmCf6irWCb3fs9YlyrPm748ZW4iWOgRx/f1t2Kwl26rZYWy5rccv66sZD3Nx/kSF0TCdHhfGVcCldPSGFYolbDSisNtcfD/pBnF83S3dBQdfyYfoMgeYLr1kka6wZwfTQnX+Eugrub1Cc7inlrfSGLdxbT1GIZnhTNteelcs34FPpq/ry05djgbfF2OLDBPQrX/fOWCzEDIMkT9gkjXSu/E/rxFe4iJyirrue9zQd5a30BmwoqCQsO4vJRiXxtUjqT+sdo/rycWXWxG7Q9uMH14R/YBJWtbooS1st15/Qb7LZb6DfA/SMQM8BtuuaFLZQV7iKnsf3gEV5dm8/bGwqpPNrIqJRo7poxgDmjkwgN1h7m8gUc226hZIfbbqF0lxu4rcw/+digUAgJh9kPwYRbzurrFO4i7VDX2Mxb6wt5avle9pbUkNKnB9+6aCDXZaUSHqIplXIOGmrgcK7bV+dwrnveXA9N9TBirrshyllQuIt8AS0tlkU7i/n9ohw27K8gMTqCuy8YwA0T0zSdUroUhbvIWbDW8mlOGY8u3M2a3HL6RoZy69RMbp2SqcFX6RIU7iLnaG1uOU8u3sPCHcX0DAvm5imZ3DmjP7G9wn1dmnRjCncRL9lx6AiPL9rDe5sPEB4SxA3np3PLlAwGxPXydWnSDSncRbxsT0k1jy/K4d1NB2hstlw4JI6bJmdw8dA4QjTDRjqJwl2kgxRX1fHy6nxeXJ1HcVU9cVHhXDM+ha+en6bWvHQ4hbtIB2tsbmHRjmJeX1fAJzuKaW6xXDo8njtnDNCiKOkwCneRTlRcVceLq/bz/Ko8ymsaGJkczS1TMrhqbAo9wjRfXrxH4S7iA8cWRT27IpedRVVER4RwfVYat07NJC0m0tflSQBQuIv4kLWWtbmHeW5lLh9uOUSLtVw6PIE7ZwxgYv8YX5cnfqy94a6ldyIdwBjDxP4xTOwfw6HKOp5flctLq/fz0bYiJmbG8J2Zg7hgcKz65aXDqOUu0kmONjTz6tr9/HHpXg5W1jEiKZp5k9KZOy6Z6IhQX5cnfkLdMiJdVENTC2+tL+DZlXlsP3iEiNAg5o5N4RsXDtBUSjkjhbtIF2et5bPCSl5ek89b6wtobG5hzugk7pk5SHeLklNqb7i3a1mdMWa2MWanMSbHGPPgaY671hhjjTFn/GKR7s4Yw5jUPjx0zWg+fXAm37hwIIt3ljDnkWX85O3PKK9p8HWJ4sfO2HI3xgQDu4BZQAGwFphnrd12wnFRwPtAGHCPtfa0zXK13EVOVlnbyO8W7uK5lXn0Cg/h3pmDuHFSurYdls95s+U+Ecix1u611jYArwBz2zjuF8CvgLovVKmIfK53ZCj/+eWR/O27Mxid0pv/en87U//3E377j12UVdf7ujzxI+0J9xSg9f2iCjyvfc4YMx5Is9a+58XaRLqtIQlRvHDnJN745hTOz4zh0YW7ufDXi/nT0j00NLX4ujzxA+35Xa+tibif9+UYY4KAh4HbzvhBxtwN3A2Qnp7evgpFurGszBiyMmPIKa7ioQ928D8f7ODVtfn8+5UjuHBInObJyym1p+VeAKS1ep4KHGj1PAoYBSw2xuQCk4EFbQ2qWmv/ZK3NstZmxcXFnX3VIt3MoPgo/nLb+Tx9WxZNLZbbnlnL3Mc/datfW3wz4026tvYMqIbgBlQvAQpxA6o3Wmu3nuL4xcAPNKAq0jHqm9z+NU8u2UNeWS1DE6L48ZxhXDQ03telSSfw2oCqtbYJuAf4O7AdeM1au9UY83NjzFXnXqqIfBHhIcHMm5jOwu9fyCM3jKOuqZnbnlnLzX9ZzY5DR3xdnnQRWsQk4ucamlp4flUejy7cTVVdIzdNzuD7s4bQJ1I39A5EXl3EJCJdV1hIEHdM78+Sf72Imydn8MKqPC76v8W8sCqPZvXHd1sKd5EA0ScyjJ/NHcX7981gWGIU//bOFr782HLW5pb7ujTxAYW7
SIAZnhTNy3dN5vc3judwbQPXPbmS772yQYuguhmFu0gAMsZw5ZhkFj5wIffOHMT7nx1k1sNLWbDpAL4aZ5POpXAXCWCRYSE8cNlQ3rt3Bml9e3Dfyxu449lsthRW+ro06WAKd5FuYGhiFG9+ayr/b84w1uaWc+Vjy/n6/LVszK/wdWnSQRTuIt1ESHAQd18wkOU/mskPLhvChv2HufqJT3nk492aVROAFO4i3UzvHqHcM3Mwy340k6vHpfDwx7u47Zk1lGrANaAo3EW6qV7hIfzm+rH87zWjWbOvnCseWcZHWw/5uizxEoW7SDdmjOGGiem8/e1p9OsZxt3Pr+Oel9Zr2mQAULiLCCOSo1lwz3QemDWEj7YWMevhpSzdVeLrsuQcKNxFBHDbGNx7yWDeu286cb3CufWZNfz2o50abPVTCncR+SdDEqJ45zvTuHZCKo9+ksO8P6/iswLNi/c3CncROUmPsGB+fd1Yfn3tGHYequLLv1/ON57PZuehKl+XJu2kcBeRU7ouK41lP7qY7106mBU5ZVzxyFJ+9u5WquubfF2anIHCXUROKzoilO9dOoRlP7qYGyelM39FLpf+ZgkfbtG0ya5M4S4i7dInMoz/+spo3vzWVPpEhvLNF9bx4JubOdrQ7OvSpA0KdxH5Qiak9+Xde6fz7YsG8mp2PnMfX87uIvXFdzUKdxH5wkKDg/jh7GE8e/tEymsa+NJjy/nZu1sprqrzdWnioXAXkbN2wZA4PvjuDK4el8JzK/O48FeL+dWHO2hoavF1ad2ewl1Ezkl8VAS/vHYMH3//Qi4bmcATi/dw69NrqKxt9HVp3ZrCXUS8on9sTx65YTy/vX4s2XnlXPOHT9lfVuvrsrothbuIeNU1E1J54Y5JlNU0MPfx5by9oUC39vMBhbuIeN2kAf1461tTyejXk/tf3cRNf1nNvtIaX5fVrSjcRaRDDIjrxZvfmsov5o5kc34ll/9uKY98vJv6Js2L7wwKdxHpMMFBhpunZLLwgQu5fGQiD3+8iyt+t4xPc0p9XVrAU7ir9AqpAAAKKklEQVSLSIeLj47gsXnjee7rE2m2lq89tZrvv7aR8poGX5cWsBTuItJpLhgSx9+/dwH3zhzEgo0HuPS3S3hnQ6EGXDuAwl1EOlVEaDAPXDaU9++bQUa/SL736ka+8fw63aDbyxTuIuITQxOjeOObU/nJnOEs3lnC5Q8v5e+6QbfXKNxFxGeCgwx3XTCAd++dTkJ0BN94fh13PZdNfrkWP50rhbuI+NzQRHdrvx/NHsanOaVc8tsl/OajnRzWgOtZM74ayMjKyrLZ2dk++W4R6boOVh7lfz7YwbubDhAeEsRXxqVw69RMRiRH+7q0LsEYs85am3XG4xTuItIV7TxUxfwVuby9oYC6xhZmjUjg+7OGMDype4e8wl1EAkJFbQPPrczjz8v2UlXXxJVjkvjBZUPJjO3p69J8or3h3q4+d2PMbGPMTmNMjjHmwTbe/74xZpsxZrMxZqExJuNsihYROVGfyDDuu2Qwy384k3suHsQnO4qZ9fASfrpgqxZBncYZW+7GmGBgFzALKADWAvOstdtaHXMxsNpaW2uM+RZwkbX2q6f7XLXcReRsFB+p4+GPd/Pq2v30DAvhuqw0bpmS0W1a8t5suU8Ecqy1e621DcArwNzWB1hrF1lrj81dWgWkftGCRUTaIz46goeuGc1H91/ARcPieW5lLhf/ZjG3P7OGRTuKaWnRaleAkHYckwLkt3peAEw6zfF3AH87l6JERM5kUHwUj80bT9GXhvPS6v28uHo/t89fS1pMD26alMGNk9KJigj1dZk+056Wu2njtTb/aTTG3ARkAb8+xft3G2OyjTHZJSUl7a9SROQUEqIjuH/WEFY8OJPH5o0nqXcPHvrbDmb8ahGPL8qhur7J1yX6RHv63KcAP7XWXu55/mMAa+1DJxx3KfAYcKG1tvhMX6w+dxHpKJsLKvjdx7v5ZEcxfSJD
uWVyBjdPySQuKtzXpZ0zr02FNMaE4AZULwEKcQOqN1prt7Y6ZjzwBjDbWru7PQUq3EWko23Kr+CxT3JYuKOI0OAgrh6Xwl0XDGBQfC9fl3bWvDrP3RgzB/gdEAw8ba39b2PMz4Fsa+0CY8zHwGjgoOdH9ltrrzrdZyrcRaSz7C2p5i/L9/HGugLqm1q4dHgC37xwAOdl9MWYtnqeuy4tYhIROUFZdT3Prszj+ZW5HK5tZGxaH74+LZM5o5MIDfaPrbYU7iIip1Db0MSb6wp45tNc9pbWkBAdzs2TM5g3MZ1+vbp2v7zCXUTkDFpaLEt2lfD0p/tYtruUsOAgvjw2mW9fPJCBcV2zX7694d6eee4iIgEpKMhw8bB4Lh4WT05xNc+tzOWNdQW8s7GQayekct+lg0np08PXZZ4VtdxFRFopra7niUV7eGFVHgBfGpPETZPTmZDeNQZf1S0jInIOCiuO8scle3hrfSHV9U0MS4zijun9+cr4FJ8OvircRUS8oKa+iXc3HWD+ilx2HKoiuXcEd10wgHkT04kIDe70ehTuIiJeZK1l8c4Snlicw9rcw6TF9OAnc0Zw+ciETu2u8ep+7iIi3Z0xbvD19W9O5cU7JxEZGsI3X1jH155azeaCCl+XdxKFu4jIFzRtUCzv3zedX8wdybaDR7jq959y57PZbD1Q6evSPqdwFxE5CyHBQdw8JZNlP7yYB2YNYc2+Mr706HJ++eGOLrGnvMJdROQcREWEcu8lg1n2o5nccH4af1i8h7ufz6aqrtGndSncRUS8oHePUB66ZjS/mDuSRTtLuOaJFeQUV/usHoW7iIiXGGO4eUomz399IqXV9Xzp0WU8vXyfT7ppFO4iIl42dVAsf7//AqYPiuXn723ja0+tpriqrlNrULiLiHSA+KgInro1i1/9yxg25ldw13PrqGts7rTvV7iLiHQQYwzXn5/Gw18dx6b8Cn781md01sJRhbuISAebPSqRB2YN4e0Nhfxx6d5O+U6Fu4hIJ7hn5iCuHJPELz/cwSc7ijr8+7Sfu4hIJzDG8Otrx1JT30TvHmEd/n0KdxGRTtIjLJhnbp/YKd+lbhkRkQCkcBcRCUAKdxGRAKRwFxEJQAp3EZEApHAXEQlACncRkQCkcBcRCUCmszaxOemLjSkB8s7yx2OBUi+W4w+62znrfANbdztf8N45Z1hr4850kM/C/VwYY7KttVm+rqMzdbdz1vkGtu52vtD556xuGRGRAKRwFxEJQP4a7n/ydQE+0N3OWecb2Lrb+UInn7Nf9rmLiMjp+WvLXURETsPvwt0YM9sYs9MYk2OMedDX9XibMSbNGLPIGLPdGLPVGPNdz+sxxph/GGN2e/7s6+tavckYE2yM2WCMec/zvL8xZrXnfF81xnT83Q06kTGmjzHmDWPMDs+1nhLI19gYc7/nv+ctxpiXjTERgXSNjTFPG2OKjTFbWr3W5vU0zqOeDNtsjJnQETX5VbgbY4KBx4ErgBHAPGPMCN9W5XVNwAPW2uHAZOA7nnN8EFhorR0MLPQ8DyTfBba3ev5L4GHP+R4G7vBJVR3nEeBDa+0wYCzu3APyGhtjUoD7gCxr7SggGLiBwLrG84HZJ7x2qut5BTDY87gb+ENHFORX4Q5MBHKstXuttQ3AK8BcH9fkVdbag9ba9Z6/V+H+p0/BneeznsOeBb7imwq9zxiTCnwJeMrz3AAzgTc8hwTa+UYDFwB/AbDWNlhrKwjga4y761sPY0wIEAkcJICusbV2KVB+wsunup5zgeesswroY4xJ8nZN/hbuKUB+q+cFntcCkjEmExgPrAYSrLUHwf0DAMT7rjKv+x3wQ6DF87wfUGGtbfI8D7TrPAAoAZ7xdEU9ZYzpSYBeY2ttIfB/wH5cqFcC6wjsawynvp6dkmP+Fu6mjdcCcrqPMaYX8CbwPWvtEV/X01GMMVcCxdbada1fbuPQQLrOIcAE4A/W2vFADQHSBdMWT1/zXKA/kAz0
xHVNnCiQrvHpdMp/3/4W7gVAWqvnqcABH9XSYYwxobhgf9Fa+5bn5aJjv7p5/iz2VX1eNg24yhiTi+tmm4lryffx/AoPgXedC4ACa+1qz/M3cGEfqNf4UmCftbbEWtsIvAVMJbCvMZz6enZKjvlbuK8FBntG2cNwgzILfFyTV3n6m/8CbLfW/rbVWwuAWz1/vxX4a2fX1hGstT+21qZaazNx1/MTa+3XgEXAtZ7DAuZ8Aay1h4B8Y8xQz0uXANsI0GuM646ZbIyJ9Pz3fex8A/Yae5zqei4AbvHMmpkMVB7rvvEqa61fPYA5wC5gD/ATX9fTAec3Hfcr2mZgo+cxB9cPvRDY7fkzxte1dsC5XwS85/n7AGANkAO8DoT7uj4vn+s4INtznd8B+gbyNQZ+BuwAtgDPA+GBdI2Bl3HjCY24lvkdp7qeuG6Zxz0Z9hluFpHXa9IKVRGRAORv3TIiItIOCncRkQCkcBcRCUAKdxGRAKRwFxEJQAp3EZEApHAXEQlACncRkQD0/wGkJl1nORmNrAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f78195884e0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(precision_threshold)\n",
    "plt.legend(loc='upper left')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 418,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.6371681415929203, 0.7105263157894737]"
      ]
     },
     "execution_count": 418,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lr_summary.recallByLabel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 419,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------------+--------------------+\n",
      "|         threshold|              recall|\n",
      "+------------------+--------------------+\n",
      "|0.9999990369941395|0.008867276887871853|\n",
      "|0.8220499817791618|0.018592677345537757|\n",
      "|0.7968126369992325| 0.04004576659038902|\n",
      "| 0.773470522075655| 0.06321510297482838|\n",
      "|0.7612946912751203| 0.08266590389016018|\n",
      "|0.7490362058233141| 0.10526315789473684|\n",
      "|0.7382521496703287|  0.1187070938215103|\n",
      "|0.7279553838380405|  0.1347254004576659|\n",
      "|0.7193784940839227| 0.15217391304347827|\n",
      "|0.7122887285503471| 0.16990846681922198|\n",
      "|0.7047990854701919| 0.18878718535469108|\n",
      "|0.6983210316089354| 0.20451945080091533|\n",
      "|0.6895039811348334| 0.22540045766590389|\n",
      "|0.6828171155434885|  0.2465675057208238|\n",
      "|0.6769693285452459|  0.2625858123569794|\n",
      "|0.6702627333842052|  0.2814645308924485|\n",
      "|0.6639336170490286| 0.29776887871853547|\n",
      "|0.6574455562430638|  0.3137871853546911|\n",
      "|0.6519155436691656|  0.3275171624713959|\n",
      "|0.6462446109735182|  0.3383867276887872|\n",
      "+------------------+--------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "lr_summary.recallByThreshold.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 420,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6732283464566929"
      ]
     },
     "execution_count": 420,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lr_summary.weightedRecall"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 421,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6749341958735193"
      ]
     },
     "execution_count": 421,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lr_summary.weightedPrecision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 422,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------------+--------------------+\n",
      "|                 FPR|                 TPR|\n",
      "+--------------------+--------------------+\n",
      "|                 0.0|                 0.0|\n",
      "| 0.00165929203539823|0.008867276887871853|\n",
      "|0.005530973451327...|0.018592677345537757|\n",
      "| 0.01133849557522124| 0.04004576659038902|\n",
      "| 0.01631637168141593| 0.06321510297482838|\n",
      "| 0.02461283185840708| 0.08266590389016018|\n",
      "|0.033185840707964605| 0.10526315789473684|\n",
      "|0.041482300884955754|  0.1187070938215103|\n",
      "|0.047289823008849555|  0.1347254004576659|\n",
      "|0.053650442477876106| 0.15217391304347827|\n",
      "|0.059734513274336286| 0.16990846681922198|\n",
      "| 0.06830752212389381| 0.18878718535469108|\n",
      "|  0.0752212389380531| 0.20451945080091533|\n",
      "| 0.08545353982300885| 0.22540045766590389|\n",
      "| 0.09098451327433628|  0.2465675057208238|\n",
      "|  0.0956858407079646|  0.2625858123569794|\n",
      "|  0.1025995575221239|  0.2814645308924485|\n",
      "| 0.10896017699115045| 0.29776887871853547|\n",
      "| 0.11559734513274336|  0.3137871853546911|\n",
      "| 0.12251106194690266|  0.3275171624713959|\n",
      "+--------------------+--------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "lr_summary.roc.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 423,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------------+------------------+\n",
      "|              recall|         precision|\n",
      "+--------------------+------------------+\n",
      "|                 0.0|0.8378378378378378|\n",
      "|0.008867276887871853|0.8378378378378378|\n",
      "|0.018592677345537757|0.7647058823529411|\n",
      "| 0.04004576659038902|0.7734806629834254|\n",
      "| 0.06321510297482838|0.7892857142857143|\n",
      "| 0.08266590389016018|0.7645502645502645|\n",
      "| 0.10526315789473684|0.7540983606557377|\n",
      "|  0.1187070938215103|0.7345132743362832|\n",
      "|  0.1347254004576659|0.7336448598130841|\n",
      "| 0.15217391304347827|0.7327823691460055|\n",
      "| 0.16990846681922198|0.7333333333333333|\n",
      "| 0.18878718535469108|0.7276736493936052|\n",
      "| 0.20451945080091533|0.7244174265450861|\n",
      "| 0.22540045766590389|0.7183226982680037|\n",
      "|  0.2465675057208238|0.7237615449202351|\n",
      "|  0.2625858123569794|0.7262658227848101|\n",
      "|  0.2814645308924485|  0.72619926199262|\n",
      "| 0.29776887871853547|0.7254355400696864|\n",
      "|  0.3137871853546911|0.7240924092409241|\n",
      "|  0.3275171624713959|0.7210327455919395|\n",
      "+--------------------+------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "lr_summary.pr.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 424,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAY0AAAD8CAYAAACLrvgBAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl4VPXZ//H3zRJA2XcQIqJRRCAsEbHW1kfkKWgLasFKa2VR6WOf1lqtCi70qZZWi9bW5aeiImitCogFFbUIaq1VIWxhl7BJAFklEPYk9++PDDaNk+RAZuZMks/rurgyZ+bLnI8HJ/ec71luc3dERESCqBF2ABERqTxUNEREJDAVDRERCUxFQ0REAlPREBGRwFQ0REQkMBUNEREJTEVDREQCU9EQEZHAaoUdINaaN2/uHTp0CDuGiEilsmDBgp3u3qK8cVWuaHTo0IHMzMywY4iIVCpmtjHIOE1PiYhIYCoaIiISmIqGiIgEpqIhIiKBqWiIiEhgKhoiIhKYioaIiASmoiEiUgW8vWwrMxZvjvt6VDRERCq5eet3c9PLi3nh440UFHpc16WiISJSiX22bR/XT55P+yb1ePraDGrWsLiuT0VDRKSS2rLnIMMmzqNu7ZpMHtmbJienxH2dKhoiIpVQ7oGjDH9uHnmH8pk0ojftmpyUkPVWuRsWiohUdYeOFnDDC5ms37mfySN607ltw4StW0VDRKQSKSh0fvnKYuat380jQ3vwjTOaJ3T9mp4SEakk3J17X1/OW8u+4O7LzmZgetuEZ1DREBGpJJ74YC2TP97IDReexvUXdgwlg4qGiEgl8OqCHP7w9moGdW/LmAFnh5ZDRUNEJMm9v3o7d7yaxQVnNGP84HRqxPlajLKoaIiIJLGsnD389MWFnNmqAU9e04uUWuH+2g517WbW38xWm1m2mY0uY9xgM3Mzy0hkPhGRMG3ctZ+Rk+bT9OQUJo04lwZ1a4cdKbyiYWY1gceBAUBnYKiZdY4yrgFwE/BpYhOKiIRnZ95hrp04j4JCZ/LI3rRsWDfsSEC4exq9gWx3X+fuR4CXgUFRxt0H/AE4lMhwIiJh2X84n5GT5rNt7yGeHX4up7eoH3akr4RZNE4BNhVbzok89xUz6wG0d/c3EhlMRCQsRwsK+emLC1m2OZfHhvakZ2qTsCP9hzCvCI92+P+re/qaWQ3gYWB4uW9kNgoYBZCamhqjeCIiieXujH51KR98toP7r+zKJZ1bhR3pa8Lc08gB2hdbbgdsKbbcAOgCvG9mG4A+wMxoB8PdfYK7Z7h7RosWLeIYWUQkfh78+2peXZjDzZekcXXv5PwCHGbRmA+kmdlpZpYCXA3MPPaiu+e6e3N37+DuHYBPgIHunhlOXBGR+Hnh4w08/t5ahvZuzy/6poUdp1ShFQ13zwd+BrwDrASmuPtyM7vXzAaGlUtEJNHeXraVsTOXc8nZrbhvUBfMwrt4rzyh3uXW3WcBs0o8N7aUsRclIpOISCIda9Xao31jHh3ag1o1k/ua6+ROJyJShR1r1dquST2eHXYu9VJqhh2pXCoaIiIh2Jpb1Kq1Tu2aTB6RmFatsaCiISKSYLkHjzJ84nz2Hcpn0ohzad80Ma1aY0FFQ0QkgQ4dLeCG5zNZtzOPCT/uxTltG4Ud6bio3auISIIUFDq3TAmvVWssaE9DRCQB3J373ljBrKXhtWqNBRUNEZEEePKDdUz61wau/2Z4rVpjQUVDRCTOpi/M4YG3VzEwvS13Xhpeq9ZYUNEQEYmjDz7bwe3TsvjG6c0YP6RbqK1aY0FFQ0QkTpbm5HLjXxaQ1qoBT/24F3VqJf/Fe+VR0RARiYONu/YzYtI8mpyUwuQkadUaCzrlVkQkxnbmHWbYxHnkFzqvXJc8rVpjQXsaIiIxtP9wPtdNms8Xew/x7LDkatUaCyoaIiIxcrSgkP/960KWbs7l0aE96XVqcrVqjQVNT4mIxIC7M2b6Ut5fvYPfX9mVfknYqjUWtKchIhIDD/39M6YtyOEXfdMYmqStWmNBRUNEpIJe+GQjj72XzdDe
7bn5kuRt1RoLKhoiIhXw9rIvGDtjGZec3TLpW7XGgoqGiMgJmr9hNze9vIju7Rvz6NCeSd+qNRaq/n+hiEgcrNm2j+smzadd48rTqjUWVDRERI7Tf7RqHdmbppWkVWssqGiIiByHY61a9x7K57nhlatVayyoaIiIBHToaAGjIq1an7ymF11OqVytWmNBF/eJiARQWOjcOmUJn67fzZ+v7s430ypfq9ZY0J6GiEg53J1731jBm0u3cuelnRjU/ZSwI4VGRUNEpBwT/lHUqnXkBadxQyVu1RoLoRYNM+tvZqvNLNvMRkd5/RYzW2FmWWY2x8xODSOniFRfry3K4fdvreK73dpw92VnV/mL98oTWtEws5rA48AAoDMw1Mw6lxi2CMhw927ANOAPiU0pItXZh2t2cNvULM7v2IyHrkqv9K1aYyHMPY3eQLa7r3P3I8DLwKDiA9z9PXc/EFn8BGiX4IwiUk0t25zL/7ywgDNa1uepa6tGq9ZYCLNonAJsKracE3muNNcBb8U1kYgI8PmuAwx/bj6NT0ph8sjeNKwirVpjIcxTbqPt53nUgWbXABnAt0t5fRQwCiA1tereklhE4m9X3mGGPTePowWFvDzqPFpVoVatsRDmnkYO0L7YcjtgS8lBZnYJcBcw0N0PR3sjd5/g7hnuntGiRYu4hBWRqu/AkXxGTs5ky56DTByewRktG4QdKemEWTTmA2lmdpqZpQBXAzOLDzCzHsBTFBWM7SFkFJFqIr+gkJ/9dRFLc/bw6NAe9Dq1adiRklJoRcPd84GfAe8AK4Ep7r7czO41s4GRYeOB+sBUM1tsZjNLeTsRkRPm7tz52lLmrtrOfZd34b/PaR12pKQV6m1E3H0WMKvEc2OLPb4k4aFEpNp5ePZnTMnM4aaLz+BH5+lysLLoinARqdZe/HQjj8zN5gcZ7fllvzPDjpP0VDREpNp6Z/kX3PO3ZVzcqSXjrqj6rVpjQUVDRKqlzA27uemlRXRt15jHftijWrRqjYVyt5KZDTGzBpHHd5vZdDPrGf9oIiLxkb19H9dNzqRt43pMHJbBSSnqEhFUkNJ6j7vvM7NvAt8BJgNPxDeWiEh8bNt7iGET51O7Zg2eH9mbZvXrhB2pUglSNAoiPy8DnnD3GUD1aYgrIlXG3kNHGTZxHnsOHGHSiOrXqjUWghSNzWb2FHAVMMvM6gT8eyIiSeNwflGr1uzteTz54+rZqjUWgvzyv4qiC/D6u/seoClwW1xTiYjEUGGhc8uUJXyybjcPDknnwjTdbuhEBSkaT7n7dHdfA+DuW4EfxzeWiEhsuDu/fXMlb2ZtZcyATlzeo/q2ao2FIEXjnOILkeZJveITR0Qktp7+cB0TP1rPiAs6MOpb1btVayyUWjTMbIyZ7QO6mdneyJ99wHZgRsISioicoL8t2szvZq3ism5tuOeyzrp4LwZKLRru/nt3bwCMd/eGkT8N3L2Zu49JYEYRkeP2zzU7uW3aEvp0bMof1ao1Zsq9osXdx5jZKcCpxce7+z/iGUxE5EQt25zLT17I5PQW9ZlwbYZatcZQuUXDzO6nqNfFCv59zYYDKhoiknQ27f53q9ZJI9SqNdaCXDt/BXBWaV3zRESSxe79R7h24r9btbZupFatsRbk7Kl1gEq1iCS1A0fyGTlpPlv2HOTZYWrVGi9B9jQOAIvNbA7w1d6Gu98Ut1QiIschv6CQn/91EVk5e3jiml5kdFCr1ngJUjRmUqJ3t4hIsnB37nptGXNWbee3l3fhO2rVGldBzp6abGb1gFR3X52ATCIigT387hpeydzEzy8+g2v6qFVrvAXpp/E9YDHwdmS5u5lpz0NEQvfipxt5ZM4arspoxy1q1ZoQQQ6E/x/QG9gD4O6LgdPimElEpFx/j7Rq/a+zWjDuiq662jtBghSNfHfPLfGcxyOMiEgQCzbu5ueRVq2P/6gntdWqNWGCHAhfZmY/BGqaWRpwE/Cv+MYSEYkue3ueWrWGKEh5
/jlFd7o9DLwE7AVujmcoEZFoilq1zqNWjRpMHqFWrWEIcvbUAeCuyB8RkVAUb9X6yk/OJ7WZWrWGodSiYWZ/cvebzex1ohzDcPeBcU0mIhJxOL+Anzy/gOzteTw34ly1ag1RWXsaL0R+PpiIICIi0RQWOrdOWcLH63bx8A/UqjVspRYNd18QeZgJHHT3Qviqc19MJhLNrD/wZ6Am8Iy731/i9TrA8xR1CtwF/MDdN8Ri3SJSOYybtZI3srYyekAnrujRLuw41V6QA+FzgOKTh/WAdyu64kjxeRwYAHQGhppZ5xLDrgO+dPczgIeBByq6XhGpPJ7+xzqe/ed6hn+jAz9Rq9akEORctbrunndswd3zzCwWR6B6A9nuvg7AzF4GBlHUt+OYQRRdXAgwDXjMzMzddZ1IEpi/YTdrt+eVPzCAWP6DxvL/Do9hsiC5yh0S4E3KG1HeWwT5eFV0HUHeY2feYZ54fy2Xdm3NPd9Vq9ZkEaRo7Deznu6+EMDMegEHY7DuU4BNxZZzgPNKG+Pu+WaWCzQDdhYfZGajgFEAqampMYgm5fnX2p388OlPw44hVdyFac3541XdqalWrUkjSNG4GZhqZlsiy22AH8Rg3dH+Lyj55SPIGNx9AjABICMjQ3shcZZ3OJ/bp2XRodlJPD/yPGrXit0H2qL+k5/ge8XorWL66yrAm5W3DYL8d5U3JMi39vLfI0iOcgaV83LDurW0h5FkglynMd/MOgFnUfRPvMrdj8Zg3TlA+2LL7YAtpYzJMbNaQCNgdwzWLRUw7s2VbN5zkKk6V16k2inrOo2L3X2umV1Z4qU0M8Pdp1dw3fMj73UasJmiPuQ/LDFmJjAM+BgYDMzV8YxwffDZDl6a9zmjvtVRjW5EqqGy9jS+BcwFvhflNQcqVDQixyh+BrxD0Sm3E919uZndC2S6+0zgWeAFM8umaA/j6oqsUyom9+BR7piWxRkt6+s21CLVVFlF48vIz2fd/Z/xWLm7zwJmlXhubLHHh4Ah8Vi3HL97X1/BjrzDPPXjXtStXTPsOCISgrKu0xgR+flIIoJIcpu9YhuvLszhxm+fTnr7xmHHEZGQlLWnsdLMNgAtzSyr2PMGuLt3i2sySRpf7j/CmOlL6dS6ATf1TQs7joiEqKzbiAw1s9YUHXPQzQmrsbEzl7PnwBEmjzyXlFpqdiNSnZV19tQcd+9rZu+4+8ZEhpLkMWvpVl5fsoVb+p3JOW11Z1GR6q6s6ak2ZvZt4Htm9hIlLsM5doW4VF078w5z99+W0a1dI2686PSw44hIEiiraIwFRlN00d0fS7zmwMXxCiXhc3fuem0peYfzeWhIunowiwhQ9jGNacA0M7vH3e9LYCZJAjMWb+Gd5dsYM6ATaa0ahB1HRJJEkK+P48zsGjMbC2BmqWbWO865JERf5B5i7Ixl9Dq1CddfqNtRi8i/BSkajwPnA0Mjy/siz0kV5O6Mnp7FkYJCHhySrruLish/CFI0znP3/wUOAbj7l0BKXFNJaKZkbuL91Tu4o38nTmt+cthxRCTJBCkaRyNd9hzAzFoAhXFNJaHI+fIA972xkj4dmzLs/A5hxxGRJBSkaDwCvAa0MrNxwD+B38U1lSRcYaFzx6tZuDvjB6dTQ9NSIhJFkH4aL5rZAqBv5KnL3X1lfGNJor346UY+yt7FuCu60L6pemSISHRBOvcB1OHfF/fpeEYVs3HXfn43axUXpjXnh73VLldESlfu9JSZ/QJ4EWgBtAT+YmY/j3cwSYzCQue2qVnUqmE88P1uaq0pImUKsqdxHUVnUO0HMLMHKOqk92g8g0liTPxoPfM27Gb84G60bVwv7DgikuSCHAg3oKDYcgHl95yXSmDtjjzGv7Oavp1aMrhXu7DjiEglEGRP4zngUzN7LbJ8OUVtWKUSyy8o5NYpS6hbuya/v7KrpqVEJJAgZ0/90czeB75J0R7GCHdfFO9gEl8TPlzH4k17+PPV3WnZsG7YcUSkkiir
n8a5QHN3fytyG/SFkecHmlkNd1+QqJASW6u/2MefZq9hQJfWDExvG3YcEalEyjqmMR6Idj3GishrUgkdLSjklimLaVC3Fr+9vIumpUTkuJRVNJq5+4aST7p7NtAsbokkrh5/L5vlW/Yy7oquNKtfJ+w4IlLJlFU0yjr/Uneyq4SWbc7lsbnZXN69Lf27tA47johUQmUVjXfNbJyVmL8ws98Ac+MbS2LtcH4Bt0xZTNOTU/jNwC5hxxGRSqqss6duBZ4Bss1sceS5dCATuD7ewSS2/vTuGj7blsdzw8+l0Um1w44jIpVUWe1e9wNDzawjcE7k6eXuvi4hySRmFn7+JU99sJarMtrxX51ahh1HRCqxINdprANiWijMrCnwCtAB2ABcFWnuVHxMd+AJoCFFV6GPc/dXYpmjOjh0tIBfTV1C64Z1ufu7ncOOIyKVXJDbiMTDaGCOu6cBcyLLJR0ArnX3c4D+wJ/MrHECM1YJ499Zzbod+/nD4HQa1tW0lIhUTFhFYxAwOfJ4MkW3JvkP7v6Zu6+JPN4CbKfoTrsS0Lz1u5n40Xqu6ZPKN9Oahx1HRKqAoP00MLOWwFf3m3D3zyuw3lbuvjXyPlsj713WuntT1MdjbQXWWa3sP5zPr6YuoX2Tkxgz4Oyw44hIFVFu0TCzgcBDQFuKvu2fStGV4ueU8/feBaJdDHDX8QQ0szbAC8Awd4/am9zMRgGjAFJT1UQI4P63VrHpywO8fEMfTq4T+LuBiEiZgvw2uQ/oA7zr7j3M7L+AoeX9JXe/pLTXzGybmbWJ7GW0oagYRRvXEHgTuNvdPyljXROACQAZGRleXraq7qPsnbzwyUZGXnAa53XUxfsiEjtBjmkcdfddQI3IjQrfA7pXcL0zgWGRx8OAGSUHmFkK8BrwvLtPreD6qo19h45y+7QsOjY/mdv7nxV2HBGpYoIUjT1mVh/4B/Cimf0ZyK/geu8H+pnZGqBfZBkzyzCzZyJjrgK+BQw3s8WRPxUtVlXeb99Yydbcgzx4VTp1a9cMO46IVDFBpqcGAQeBXwI/AhoBv6nISiN7Ln2jPP/V1ebu/hfgLxVZT3Xz3qrtvJK5if/59un0TG0SdhwRqYKC7GmMdfdCd89398nu/ghwR7yDyfHJPXCU0dOzOLNVfX7ZLy3sOCJSRQUpGv2iPDcg1kGkYv7v9eXsyjvCQ0O6U6eWpqVEJD7K6tx3I/BToKOZZRV7qQHwUbyDSXBvL/uC1xZt5hd90+jarlHYcUSkCivrmMZfgbeA3/Oft/nY5+6745pKAtuVd5i7XlvKOW0b8rOLzwg7johUcaVOT7l7rrtvcPehQHvgYnffSNGpt6clLKGUyt25Z8Yy9h46ykNXpVO7Zlh3hRGR6qLc3zJm9muKDnyPiTyVgs5qSgqvZ21l1tIvuPmSM+nUumHYcUSkGgjy1fQKYCCwH766eWCDeIaS8m3fd4ixM5aR3r4xP/lWx7DjiEg1EaRoHHF3BxzAzNQfPGTuzp3Tl3LwSAEPDUmnlqalRCRBgvy2mWJmTwGNzewG4F3g6fjGkrK8unAz767czm3fOYszWtYPO46IVCNBOvc9aGb9gL3AWRRd7Dc77skkqq25B/nN68s5t0MTRlyg8xFEJLEC3TM7UiRmm1lzYFd8I0lp3J3bp2WRX+A8OCSdmjUs7EgiUs2UOj1lZn3M7H0zm25mPcxsGbAM2GZm/RMXUY55ad4mPlyzkzGXduLUZjq0JCKJV9aexmPAnRTdoHAuMMDdPzGzTsBLwNsJyCcRm3YfYNybK/jG6c245rxTw44jItVUWQfCa7n73yO9LL441gTJ3VclJpocU1jo3DZtCWbGHwZ3o4ampUQkJGUVjeKtVQ+WeK3ad8dLpOc/3sAn63Zz92Vn067JSWHHEZFqrKzpqXQz2wsYUC/ymMhy3bgnEwDW79zP/W+v4qKzWvCDc9uHHUdEqrlSi4a76/7aISsodH41dQkpNWtw/5XdMNO0
lIiEK9AptxKOZz5cx4KNX/LwD9Jp3Ug7dyISPt1/Ikmt2baPh2Z/xn93bsXl3U8JO46ICKCikZTyCwq5deoSTk6pybgrumpaSkSShqanktAT768lKyeXx3/YkxYN6oQdR0TkK9rTSDIrtuzlkblr+G63NlzWrU3YcURE/oOKRhI5kl/ILVMW06heCvcN6hJ2HBGRr9H0VBJ5dO4aVn2xj6evzaDJySlhxxER+RrtaSSJJZv28P/eX8uVPU+hX+dWYccREYlKRSMJHDpawK1Tl9Cifh1+/b1zwo4jIlKqUIqGmTU1s9lmtibys0kZYxua2WYzeyyRGRPp4dmfkb09j/u/35VG9WqHHUdEpFRh7WmMBua4exowJ7JcmvuADxKSKgQLNu5mwofrGNq7PRed1TLsOCIiZQqraAwCJkceTwYujzbIzHoBrYC/JyhXQh08UsCvpmbRtlE97rqsc9hxRETKFVbRaOXuWwEiP7/2FdvMagAPAbclOFvCPPD2Ktbv3M/4Id2oX0cnsolI8ovbbyozexdoHeWluwK+xU+BWe6+qbzbaJjZKGAUQGpq6vHEDM3Ha3cx6V8bGHb+qXzj9OZhxxERCSRuRcPdLyntNTPbZmZt3H2rmbUBtkcZdj5woZn9FKgPpJhZnrt/7fiHu08AJgBkZGQkfYOovMP53DZtCR2ancQdAzqFHUdEJLCw5kRmAsOA+yM/Z5Qc4O4/OvbYzIYDGdEKRmU07s2VbN5zkKk/OZ+TUjQtJSKVR1jHNO4H+pnZGqBfZBkzyzCzZ0LKlBAffLaDl+Z9zg0XdiSjQ9Ow44iIHBdzT/rZnOOSkZHhmZmZYceIKvfgUb7z8D+oX7cWb/z8m9StreaIIpIczGyBu2eUN05XhCfQva+vYEfeYR4akq6CISKVkopGgsxesY1XF+Zw47dPJ71947DjiIicEBWNBPhy/xHGTF9Kp9YNuKlvWthxREROmE7dSYCxM5ez58ARJo88l5RaqtMiUnnpN1iczVq6ldeXbOGmvmmc07ZR2HFERCpERSOOduYd5u6/LaPrKY248aLTw44jIlJhKhpx4u7c9dpS8g7l89BV6dSuqU0tIpWffpPFyYzFW3hn+TZu+e8zObNVg7DjiIjEhIpGHGzbe4ixM5bRM7UxN1zYMew4IiIxo6IRY+7O6FezOFJQyIND0qlZo+w79IqIVCYqGjE2NTOH91bv4PbvdKJji/phxxERiSkVjRjK+fIA976xgvNOa8rwb3QIO46ISMypaMRIYaFzx6tZFLozfnA6NTQtJSJVkIpGjLz46UY+yt7FnZeeTWqzk8KOIyISFyoaMbBx135+N2sVF6Y150fnVY52syIiJ0JFo4IKC53bpmZRq4bxwPe7UV4/cxGRykxFo4ImfrSeeRt2M/Z7nWnbuF7YcURE4kpFowLW7shj/Dur6dupJYN7tQs7johI3KlonKD8gkJunbKEurVr8vsru2paSkSqBfXTOEETPlzH4k17+PPV3WnZsG7YcUREEkJ7Gidg9Rf7+NPsNQzo0pqB6W3DjiMikjAqGsfpaEEht0xZTIO6tfjt5V00LSUi1Yqmp47T4+9ls3zLXp68pifN6tcJO46ISEJpT+M4LNucy2NzsxnUvS39u7QJO46ISMKpaAR0OL+AW6csoenJKfxm4DlhxxERCYWmpwL687trWL1tHxOHZ9D4pJSw44iIhEJ7GgEs+vxLnvxgLUN6tePiTq3CjiMiEppQioaZNTWz2Wa2JvKzSSnjUs3s72a20sxWmFmHxCaFQ0cLuHXqElo3rMs93+uc6NWLiCSVsPY0RgNz3D0NmBNZjuZ5YLy7nw30BrYnKN9Xxr+zmnU79vPA4G40rFs70asXEUkqYRWNQcDkyOPJwOUlB5hZZ6CWu88GcPc8dz+QuIgwb/1uJn60nh+dl8qFaS0SuWoRkaQUVtFo5e5bASI/W0YZcyawx8ymm9kiMxtvZjWjvZmZjTKzTDPL3LFj
R0wC7j+cz6+mLqFdk3rceenZMXlPEZHKLm5nT5nZu0DrKC/dFfAtagEXAj2Az4FXgOHAsyUHuvsEYAJARkaGn0Dcr7n/rVVs+vIAL9/Qh5Pr6CQzERGIY9Fw90tKe83MtplZG3ffamZtiH6sIgdY5O7rIn/nb0AfohSNWPsoeycvfLKRkRecxnkdm8V7dSIilUZY01MzgWGRx8OAGVHGzAeamNmxgwkXAyviHWzfoaPcPi2Ljs1P5vb+Z8V7dSIilUpYReN+oJ+ZrQH6RZYxswwzewbA3QuAXwFzzGwpYMDT8Q722zdWsjX3IA9elU7d2lEPoYiIVFuhTNa7+y6gb5TnM4Hriy3PBrolKtd7q7bzSuYm/ufbp9MzNeqlIyIi1ZquCI/IPXCU0dOzOLNVfX7ZLy3sOCIiSUmnBUUcKSik6ymN+UXfNOrU0rSUiEg0KhoRLRrU4ZlhGWHHEBFJapqeEhGRwFQ0REQkMBUNEREJTEVDREQCU9EQEZHAVDRERCQwFQ0REQlMRUNERAIz95i0n0gaZrYD2FiBt2gO7IxRnFhSruOjXMdHuY5PVcx1qruX26K0yhWNijKzTHdPukvDlev4KNfxUa7jU51zaXpKREQCU9EQEZHAVDS+bkLYAUqhXMdHuY6Pch2faptLxzRERCQw7WmIiEhg1bJomFl/M1ttZtlmNjrK63XM7JXI65+aWYckyTXczHaY2eLIn+ujvU8cck00s+1mtqyU183MHonkzjKznkmS6yIzyy22vcYmKFd7M3vPzFaa2XIz+0WUMQnfZgFzJXybmVldM5tnZksiuX4TZUzCP5MBc4XymYysu6aZLTKzN6K8Fr/t5e7V6g9QE1gLdARSgCVA5xJjfgo8GXl8NfBKkuQaDjwWwjb7FtATWFbK65cCbwEG9AE+TZJcFwFvhLC92gA9I48bAJ9F+bdM+DYLmCvh2yyyDepHHtcGPgX6lBgTxmcySK5QPpORdd8C/DXav1c8t1d13NPoDWS7+zp3PwK8DAwqMWYQMDnyeBrQ18wsCXKFwt2T7c7GAAAC20lEQVT/AewuY8gg4Hkv8gnQ2MzaJEGuULj7VndfGHm8D1gJnFJiWMK3WcBcCRfZBnmRxdqRPyUPtib8MxkwVyjMrB1wGfBMKUPitr2qY9E4BdhUbDmHr39wvhrj7vlALtAsCXIBfD8ynTHNzNrHOVNQQbOH4fzI9MJbZnZOolcemRboQdG31OJC3WZl5IIQtllkqmUxsB2Y7e6lbq8EfiaD5IJwPpN/Am4HCkt5PW7bqzoWjWjVtuS3hyBjYi3IOl8HOrh7N+Bd/v1NImxhbK8gFlJ0a4R04FHgb4lcuZnVB14Fbnb3vSVfjvJXErLNyskVyjZz9wJ37w60A3qbWZcSQ0LZXgFyJfwzaWbfBba7+4KyhkV5LibbqzoWjRyg+LeBdsCW0saYWS2gEfGfBik3l7vvcvfDkcWngV5xzhRUkG2acO6+99j0grvPAmqbWfNErNvMalP0i/lFd58eZUgo26y8XGFus8g69wDvA/1LvBTGZ7LcXCF9Ji8ABprZBoqmsS82s7+UGBO37VUdi8Z8IM3MTjOzFIoOEs0sMWYmMCzyeDAw1yNHlMLMVWLOeyBFc9LJYCZwbeSMoD5ArrtvDTuUmbU+No9rZr0p+v99VwLWa8CzwEp3/2MpwxK+zYLkCmObmVkLM2sceVwPuARYVWJYwj+TQXKF8Zl09zHu3s7dO1D0e2Kuu19TYljctletWLxJZeLu+Wb2M+Adis5Ymujuy83sXiDT3WdS9MF6wcyyKarOVydJrpvMbCCQH8k1PN65AMzsJYrOqmluZjnAryk6KIi7PwnMouhsoGzgADAiSXINBm40s3zgIHB1Aoo/FH0T/DGwNDIfDnAnkFosWxjbLEiuMLZZG2CymdWkqEhNcfc3wv5MBswVymcymkRtL10RLiIigVXH6SkRETlBKhoiIhKYioaIiASmoiEiIoGp
aIiISGAqGiIiEpiKhoiIBKaiISIigf1/HtB4EIvpjHYAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f781e53f0f0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "beta = np.sort(lrModel.coefficients)\n",
    "plt.plot(beta)\n",
    "plt.ylabel('Beta Coefficients')\n",
    "plt.show()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 425,
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "'LinearRegressionTrainingSummary' object has no attribute 'roc'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-425-99c9c7551146>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mtrainingSummary\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlrModel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msummary\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mroc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrainingSummary\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mroc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtoPandas\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      3\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mroc\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'FPR'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mroc\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'TPR'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mylabel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'False Positive Rate'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mxlabel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'True Positive Rate'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mAttributeError\u001b[0m: 'LinearRegressionTrainingSummary' object has no attribute 'roc'"
     ]
    }
   ],
   "source": [
    "trainingSummary = lrModel.summary\n",
    "roc = trainingSummary.roc.toPandas()\n",
    "plt.plot(roc['FPR'],roc['TPR'])\n",
    "plt.ylabel('True Positive Rate')\n",
    "plt.xlabel('False Positive Rate')\n",
    "plt.title('ROC Curve')\n",
    "plt.show()\n",
    "print('Training set areaUnderROC: ' + str(trainingSummary.areaUnderROC))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "pr = trainingSummary.pr.toPandas()\n",
    "plt.plot(pr['recall'],pr['precision'])\n",
    "plt.ylabel('Precision')\n",
    "plt.xlabel('Recall')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "predictions = lrModel.transform(test)\n",
    "predictions.select('label', 'rawPrediction', 'prediction', 'probability').show(10)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test Data Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 426,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['features', 'label', 'rawPrediction', 'probability', 'prediction']"
      ]
     },
     "execution_count": 426,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_predictions = lr_model.transform(test)\n",
    "model_predictions.columns\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 206,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+----------------------------------------+----------+\n",
      "|label|probability                             |prediction|\n",
      "+-----+----------------------------------------+----------+\n",
      "|1    |[0.7921142321252156,0.2078857678747844] |0.0       |\n",
      "|0    |[0.7332732992042076,0.2667267007957924] |0.0       |\n",
      "|0    |[0.7058440774094991,0.294155922590501]  |0.0       |\n",
      "|0    |[0.5801868096061101,0.41981319039388987]|0.0       |\n",
      "|1    |[0.5735471647731092,0.42645283522689076]|0.0       |\n",
      "|1    |[0.5601903397477959,0.4398096602522042] |0.0       |\n",
      "|0    |[0.5332326656582489,0.46676733434175105]|0.0       |\n",
      "|1    |[0.5332326656582489,0.46676733434175105]|0.0       |\n",
      "|1    |[0.5196706190017784,0.48032938099822164]|0.0       |\n",
      "|1    |[0.4585716734035034,0.5414283265964965] |1.0       |\n",
      "+-----+----------------------------------------+----------+\n",
      "only showing top 10 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "model_predictions.select(['label','probability','prediction']).show(10,False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 427,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model_predictions = lr_model.evaluate(test)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 428,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6592716617831729"
      ]
     },
     "execution_count": 428,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_predictions.accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 429,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6627412414584508"
      ]
     },
     "execution_count": 429,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_predictions.weightedPrecision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 430,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6592716617831729"
      ]
     },
     "execution_count": 430,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_predictions.weightedRecall"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 431,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.6208835341365462, 0.701048951048951]"
      ]
     },
     "execution_count": 431,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_predictions.recallByLabel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 207,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.evaluation import BinaryClassificationEvaluator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 235,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The auc value of Logistic Regression Model is 0.7092938229110143\n",
      "The aupr value of Logistic Regression Model is 0.6630743130940658\n"
     ]
    }
   ],
   "source": [
    "# Area under the ROC curve of the model on test data\n",
    "lr_evaluator = BinaryClassificationEvaluator(metricName='areaUnderROC')\n",
    "lr_auroc = lr_evaluator.evaluate(model_predictions)\n",
    "print(f'The auc value of Logistic Regression Model is {lr_auroc}')\n",
    "\n",
    "\n",
    "# Area under the precision-recall curve of the model on test data\n",
    "lr_evaluator = BinaryClassificationEvaluator(metricName='areaUnderPR')\n",
    "lr_aupr = lr_evaluator.evaluate(model_predictions)\n",
    "print(f'The aupr value of Logistic Regression Model is {lr_aupr}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 236,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "true_pos=model_predictions.filter(model_predictions['label']==1).filter(model_predictions['prediction']==1).count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 237,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "actual_pos=model_predictions.filter(model_predictions['label']==1).count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 238,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6701030927835051"
      ]
     },
     "execution_count": 238,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Recall \n",
    "float(true_pos)/(actual_pos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 239,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "pred_pos=model_predictions.filter(model_predictions['prediction']==1).count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 240,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6478405315614618"
      ]
     },
     "execution_count": 240,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Precision\n",
    "float(true_pos)/(pred_pos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 241,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.classification import DecisionTreeClassifier\n",
    "dt = DecisionTreeClassifier()\n",
    "dt_model = dt.fit(train)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 242,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model_predictions = dt_model.transform(test)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 243,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+----------------------------------------+----------+\n",
      "|label|probability                             |prediction|\n",
      "+-----+----------------------------------------+----------+\n",
      "|1    |[0.8089887640449438,0.19101123595505617]|0.0       |\n",
      "|0    |[0.8089887640449438,0.19101123595505617]|0.0       |\n",
      "|0    |[0.8089887640449438,0.19101123595505617]|0.0       |\n",
      "|0    |[0.3055555555555556,0.6944444444444444] |1.0       |\n",
      "|1    |[0.3055555555555556,0.6944444444444444] |1.0       |\n",
      "|1    |[0.3055555555555556,0.6944444444444444] |1.0       |\n",
      "|0    |[0.42972350230414746,0.5702764976958525]|1.0       |\n",
      "|1    |[0.42972350230414746,0.5702764976958525]|1.0       |\n",
      "|1    |[0.42972350230414746,0.5702764976958525]|1.0       |\n",
      "|1    |[0.42972350230414746,0.5702764976958525]|1.0       |\n",
      "+-----+----------------------------------------+----------+\n",
      "only showing top 10 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "model_predictions.select(['label','probability','prediction']).show(10,False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 244,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The auc value of Decision Tree Classifier Model is 0.516199386190993\n",
      "The aupr value of Decision Tree Classifier Model is 0.46771834172588167\n"
     ]
    }
   ],
   "source": [
    "# Area under the ROC curve of the model on test data \n",
    "dt_evaluator = BinaryClassificationEvaluator(metricName='areaUnderROC')\n",
    "dt_auroc = dt_evaluator.evaluate(model_predictions)\n",
    "print(f'The auc value of Decision Tree Classifier Model is {dt_auroc}')\n",
    "\n",
    "\n",
    "# Area under the precision-recall curve of the model on test data \n",
    "dt_evaluator = BinaryClassificationEvaluator(metricName='areaUnderPR')\n",
    "dt_aupr = dt_evaluator.evaluate(model_predictions)\n",
    "print(f'The aupr value of Decision Tree Classifier Model is {dt_aupr}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 245,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "true_pos=model_predictions.filter(model_predictions['label']==1).filter(model_predictions['prediction']==1).count()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 246,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "actual_pos=model_predictions.filter(model_predictions['label']==1).count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 247,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "pred_pos=model_predictions.filter(model_predictions['prediction']==1).count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 248,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6907216494845361"
      ]
     },
     "execution_count": 248,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Recall \n",
    "float(true_pos)/(actual_pos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 249,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6661143330571665"
      ]
     },
     "execution_count": 249,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Precision on test Data \n",
    "float(true_pos)/(pred_pos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 267,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "## RF \n",
    "\n",
    "from pyspark.ml.classification import RandomForestClassifier\n",
    "rf = RandomForestClassifier(numTrees=50,maxDepth=30)\n",
    "rf_model = rf.fit(train)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 268,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model_predictions = rf_model.transform(test)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 269,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+----------------------------------------+----------+\n",
      "|label|probability                             |prediction|\n",
      "+-----+----------------------------------------+----------+\n",
      "|1    |[0.7791249643670785,0.2208750356329216] |0.0       |\n",
      "|0    |[0.7950008034402859,0.20499919655971408]|0.0       |\n",
      "|0    |[0.7753367882511633,0.22466321174883672]|0.0       |\n",
      "|0    |[0.4400825365852292,0.5599174634147708] |1.0       |\n",
      "|1    |[0.50177151730653,0.49822848269347]     |0.0       |\n",
      "|1    |[0.489823569376324,0.510176430623676]   |1.0       |\n",
      "|0    |[0.48614770456597733,0.5138522954340227]|1.0       |\n",
      "|1    |[0.48614770456597733,0.5138522954340227]|1.0       |\n",
      "|1    |[0.46395073486900773,0.5360492651309924]|1.0       |\n",
      "|1    |[0.46459640878219743,0.5354035912178025]|1.0       |\n",
      "+-----+----------------------------------------+----------+\n",
      "only showing top 10 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "model_predictions.select(['label','probability','prediction']).show(10,False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 270,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The auc value of RandomForestClassifier Model is 0.7326433634020617\n"
     ]
    }
   ],
   "source": [
    "rf_evaluator = BinaryClassificationEvaluator(metricName='areaUnderROC')\n",
    "rf_auroc = rf_evaluator.evaluate(model_predictions)\n",
    "print(f'The auc value of RandomForestClassifier Model is {rf_auroc}')\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 271,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The aupr value of RandomForestClassifier Model is 0.7277253895494864\n"
     ]
    }
   ],
   "source": [
    "rf_evaluator = BinaryClassificationEvaluator(metricName='areaUnderPR')\n",
    "rf_aupr = rf_evaluator.evaluate(model_predictions)\n",
    "print(f'The aupr value of RandomForestClassifier Model is {rf_aupr}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 272,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "true_pos=model_predictions.filter(model_predictions['label']==1).filter(model_predictions['prediction']==1).count()\n",
    "actual_pos=model_predictions.filter(model_predictions['label']==1).count()\n",
    "pred_pos=model_predictions.filter(model_predictions['prediction']==1).count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 273,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6666666666666666"
      ]
     },
     "execution_count": 273,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Recall \n",
    "float(true_pos)/(actual_pos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 274,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6672398968185727"
      ]
     },
     "execution_count": 274,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Precision on test Data \n",
    "float(true_pos)/(pred_pos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 276,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.classification import GBTClassifier\n",
    "gbt = GBTClassifier()\n",
    "gbt_model = gbt.fit(train)\n",
    "model_predictions = gbt_model.transform(test)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 277,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+----------------------------------------+----------+\n",
      "|label|probability                             |prediction|\n",
      "+-----+----------------------------------------+----------+\n",
      "|1    |[0.8118906324166552,0.18810936758334484]|0.0       |\n",
      "|0    |[0.8045448087863976,0.19545519121360244]|0.0       |\n",
      "|0    |[0.7451850305979334,0.2548149694020666] |0.0       |\n",
      "|0    |[0.30631795902192194,0.6936820409780781]|1.0       |\n",
      "|1    |[0.31352609339453574,0.6864739066054643]|1.0       |\n",
      "|1    |[0.3281403432684641,0.6718596567315359] |1.0       |\n",
      "|0    |[0.5130658281231842,0.48693417187681576]|0.0       |\n",
      "|1    |[0.5130658281231842,0.48693417187681576]|0.0       |\n",
      "|1    |[0.48461420867158717,0.5153857913284128]|1.0       |\n",
      "|1    |[0.4717208467161091,0.5282791532838909] |1.0       |\n",
      "+-----+----------------------------------------+----------+\n",
      "only showing top 10 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "model_predictions.select(['label','probability','prediction']).show(10,False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 278,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The auc value of GradientBoostedTreesClassifier  is 0.7392410330756018\n"
     ]
    }
   ],
   "source": [
    "gbt_evaluator = BinaryClassificationEvaluator(metricName='areaUnderROC')\n",
    "gbt_auroc = gbt_evaluator.evaluate(model_predictions)\n",
    "print(f'The auc value of GradientBoostedTreesClassifier  is {gbt_auroc}')\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 279,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The aupr value of GradientBoostedTreesClassifier Model is 0.7345982892755392\n"
     ]
    }
   ],
   "source": [
    "gbt_evaluator = BinaryClassificationEvaluator(metricName='areaUnderPR')\n",
    "gbt_aupr = gbt_evaluator.evaluate(model_predictions)\n",
    "print(f'The aupr value of GradientBoostedTreesClassifier Model is {gbt_aupr}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 280,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "true_pos=model_predictions.filter(model_predictions['label']==1).filter(model_predictions['prediction']==1).count()\n",
    "actual_pos=model_predictions.filter(model_predictions['label']==1).count()\n",
    "pred_pos=model_predictions.filter(model_predictions['prediction']==1).count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 281,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6683848797250859"
      ]
     },
     "execution_count": 281,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Recall on test data\n",
    "float(true_pos)/(actual_pos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 282,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6747614917606245"
      ]
     },
     "execution_count": 282,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Precision on test data\n",
    "float(true_pos)/(pred_pos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 284,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.classification import LinearSVC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 287,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "lsvc = LinearSVC()\n",
    "# Fit the model\n",
    "lsvc_model = lsvc.fit(train)\n",
    "model_predictions = lsvc_model.transform(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 288,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['features', 'label', 'rawPrediction', 'prediction']"
      ]
     },
     "execution_count": 288,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_predictions.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 289,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+----------+\n",
      "|label|prediction|\n",
      "+-----+----------+\n",
      "|1    |0.0       |\n",
      "|0    |0.0       |\n",
      "|0    |0.0       |\n",
      "|0    |1.0       |\n",
      "|1    |1.0       |\n",
      "|1    |1.0       |\n",
      "|0    |1.0       |\n",
      "|1    |1.0       |\n",
      "|1    |1.0       |\n",
      "|1    |1.0       |\n",
      "+-----+----------+\n",
      "only showing top 10 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "model_predictions.select(['label','prediction']).show(10,False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 290,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The auc value of SupportVectorClassifier  is 0.7043772749366973\n"
     ]
    }
   ],
   "source": [
    "svc_evaluator = BinaryClassificationEvaluator(metricName='areaUnderROC')\n",
    "svc_auroc = svc_evaluator.evaluate(model_predictions)\n",
    "print(f'The auc value of SupportVectorClassifier  is {svc_auroc}')\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 291,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The aupr value of SupportVectorClassifier Model is 0.6567277377856992\n"
     ]
    }
   ],
   "source": [
    "svc_evaluator = BinaryClassificationEvaluator(metricName='areaUnderPR')\n",
    "svc_aupr = svc_evaluator.evaluate(model_predictions)\n",
    "print(f'The aupr value of SupportVectorClassifier Model is {svc_aupr}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 292,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "true_pos=model_predictions.filter(model_predictions['label']==1).filter(model_predictions['prediction']==1).count()\n",
    "actual_pos=model_predictions.filter(model_predictions['label']==1).count()\n",
    "pred_pos=model_predictions.filter(model_predictions['prediction']==1).count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 293,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7774914089347079"
      ]
     },
     "execution_count": 293,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Recall on test data\n",
    "float(true_pos)/(actual_pos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 294,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.600132625994695"
      ]
     },
     "execution_count": 294,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Precision on test data\n",
    "float(true_pos)/(pred_pos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 296,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.classification import NaiveBayes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 297,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "nb = NaiveBayes()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 298,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "nb_model = nb.fit(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 299,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model_predictions = nb_model.transform(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 300,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['features', 'label', 'rawPrediction', 'probability', 'prediction']"
      ]
     },
     "execution_count": 300,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_predictions.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 301,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+----------------------------------------+----------+\n",
      "|label|probability                             |prediction|\n",
      "+-----+----------------------------------------+----------+\n",
      "|1    |[0.48630071280188164,0.5136992871981184]|1.0       |\n",
      "|0    |[0.49911662267320056,0.5008833773267994]|1.0       |\n",
      "|0    |[0.5044578946407129,0.49554210535928717]|0.0       |\n",
      "|0    |[0.40355230543479254,0.5964476945652075]|1.0       |\n",
      "|1    |[0.4045812569659479,0.5954187430340521] |1.0       |\n",
      "|1    |[0.4066416699033005,0.5933583300966995] |1.0       |\n",
      "|0    |[0.41077229948270594,0.589227700517294] |1.0       |\n",
      "|1    |[0.41077229948270594,0.589227700517294] |1.0       |\n",
      "|1    |[0.41284238036891824,0.5871576196310817]|1.0       |\n",
      "|1    |[0.42219493600486624,0.5778050639951338]|1.0       |\n",
      "+-----+----------------------------------------+----------+\n",
      "only showing top 10 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "model_predictions.select(['label','probability','prediction']).show(10,False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 303,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The auc value of NB Classifier  is 0.43543736717760884\n"
     ]
    }
   ],
   "source": [
    "nb_evaluator = BinaryClassificationEvaluator(metricName='areaUnderROC')\n",
    "nb_auroc = nb_evaluator.evaluate(model_predictions)\n",
    "print(f'The auc value of NB Classifier  is {nb_auroc}')\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 304,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The aupr value of NB Classifier Model is 0.4321001351769349\n"
     ]
    }
   ],
   "source": [
    "nb_evaluator = BinaryClassificationEvaluator(metricName='areaUnderPR')\n",
    "nb_aupr =nb_evaluator.evaluate(model_predictions)\n",
    "print(f'The aupr value of NB Classifier Model is {nb_aupr}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 305,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "true_pos=model_predictions.filter(model_predictions['label']==1).filter(model_predictions['prediction']==1).count()\n",
    "actual_pos=model_predictions.filter(model_predictions['label']==1).count()\n",
    "pred_pos=model_predictions.filter(model_predictions['prediction']==1).count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 306,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5867697594501718"
      ]
     },
     "execution_count": 306,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Recall on test data\n",
    "float(true_pos)/(actual_pos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 307,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6254578754578755"
      ]
     },
     "execution_count": 307,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Precision on test data\n",
    "float(true_pos)/(pred_pos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 309,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "rf = RandomForestClassifier()\n",
    "rf_model = rf.fit(train)\n",
    "model_predictions = rf_model.transform(test)\n",
    "rf_evaluator = BinaryClassificationEvaluator(metricName='areaUnderROC')\n",
    "rf_auroc = rf_evaluator.evaluate(model_predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 321,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.tuning import ParamGridBuilder, CrossValidator\n",
    "evaluator = BinaryClassificationEvaluator()\n",
    "\n",
    "rf = RandomForestClassifier()\n",
    "paramGrid = (ParamGridBuilder()\n",
    "             .addGrid(rf.maxDepth, [5,10,20,25,30])\n",
    "             .addGrid(rf.maxBins, [20,30,40,50 ])\n",
    "             .addGrid(rf.numTrees, [5, 20,50,100])\n",
    "             .build())\n",
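    "# use this cell's evaluator (areaUnderROC by default) so the cell does not depend on state from earlier cells\n",
    "cv = CrossValidator(estimator=rf, estimatorParamMaps=paramGrid, evaluator=evaluator, numFolds=10)\n",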
    "cv_model = cv.fit(train)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 322,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "best_rf_model = cv_model.bestModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 323,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Generate predictions for the held-out test set\n",
    "model_predictions = best_rf_model.transform(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 324,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Evaluate best model\n",
    "rf_evaluator = BinaryClassificationEvaluator(metricName='areaUnderROC')\n",
    "rf_auroc = rf_evaluator.evaluate(model_predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 325,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7425990374615659"
      ]
     },
     "execution_count": 325,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rf_auroc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 326,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "true_pos=model_predictions.filter(model_predictions['label']==1).filter(model_predictions['prediction']==1).count()\n",
    "actual_pos=model_predictions.filter(model_predictions['label']==1).count()\n",
    "pred_pos=model_predictions.filter(model_predictions['prediction']==1).count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 327,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6520618556701031"
      ]
     },
     "execution_count": 327,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Recall of the tuned model on test data\n",
    "float(true_pos)/(actual_pos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 328,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6825539568345323"
      ]
     },
     "execution_count": 328,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Precision of the tuned model on test data\n",
    "float(true_pos)/(pred_pos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
